From 539e90bd72e0f0735a5eee497064bf9aeca257f9 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 15 Oct 2024 16:48:53 -0500 Subject: [PATCH 01/74] checkpoint --- ecc/bls12-377/fr/element.go | 20 -- ecc/bls12-377/fr/element_ops_arm64.go | 28 ++ ecc/bls12-377/fr/element_ops_arm64.s | 6 + ecc/bls12-381/fr/element.go | 20 -- ecc/bls12-381/fr/element_ops_arm64.go | 28 ++ ecc/bls12-381/fr/element_ops_arm64.s | 6 + ecc/bls24-315/fr/element.go | 20 -- ecc/bls24-315/fr/element_ops_arm64.go | 28 ++ ecc/bls24-315/fr/element_ops_arm64.s | 6 + ecc/bls24-317/fr/element.go | 20 -- ecc/bls24-317/fr/element_ops_arm64.go | 28 ++ ecc/bls24-317/fr/element_ops_arm64.s | 6 + ecc/bn254/fp/element.go | 20 -- ecc/bn254/fp/element_ops_arm64.go | 28 ++ ecc/bn254/fp/element_ops_arm64.s | 6 + ecc/bn254/fr/element.go | 20 -- ecc/bn254/fr/element_ops_arm64.go | 28 ++ ecc/bn254/fr/element_ops_arm64.s | 6 + ecc/stark-curve/fp/element.go | 20 -- ecc/stark-curve/fp/element_ops_arm64.go | 28 ++ ecc/stark-curve/fp/element_ops_arm64.s | 6 + ecc/stark-curve/fr/element.go | 20 -- ecc/stark-curve/fr/element_ops_arm64.go | 28 ++ ecc/stark-curve/fr/element_ops_arm64.s | 6 + field/asm/element_4w_arm64.s | 131 ++++++++ field/generator/asm/amd64/build.go | 7 - field/generator/asm/arm64/build.go | 287 +++++++++++++++++ field/generator/asm/arm64/element_ops.go | 292 ++++++++++++++++++ field/generator/config/field_config.go | 2 + field/generator/generator.go | 109 ++++++- field/generator/generator_test.go | 8 +- .../internal/templates/element/base.go | 2 + .../internal/templates/element/ops_asm.go | 12 + go.mod | 4 +- go.sum | 2 - internal/generator/main.go | 9 +- 36 files changed, 1105 insertions(+), 192 deletions(-) create mode 100644 ecc/bls12-377/fr/element_ops_arm64.go create mode 100644 ecc/bls12-377/fr/element_ops_arm64.s create mode 100644 ecc/bls12-381/fr/element_ops_arm64.go create mode 100644 ecc/bls12-381/fr/element_ops_arm64.s create mode 100644 ecc/bls24-315/fr/element_ops_arm64.go create mode 100644 ecc/bls24-315/fr/element_ops_arm64.s create mode 100644 ecc/bls24-317/fr/element_ops_arm64.go create mode 100644 ecc/bls24-317/fr/element_ops_arm64.s create mode 100644 ecc/bn254/fp/element_ops_arm64.go create mode 100644 ecc/bn254/fp/element_ops_arm64.s create mode 100644 ecc/bn254/fr/element_ops_arm64.go create mode 100644 ecc/bn254/fr/element_ops_arm64.s create mode 100644 ecc/stark-curve/fp/element_ops_arm64.go create mode 100644 ecc/stark-curve/fp/element_ops_arm64.s create mode 100644 ecc/stark-curve/fr/element_ops_arm64.go create mode 100644 ecc/stark-curve/fr/element_ops_arm64.s create mode 100644 field/asm/element_4w_arm64.s create mode 100644 field/generator/asm/arm64/build.go create mode 100644 field/generator/asm/arm64/element_ops.go diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index af277e8bb1..5e34cbd2ba 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Double z = x + x (mod q), aka Lsh 1 
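For reference, the portable Add removed here (and from each element.go that follows) is the textbook add-then-conditionally-subtract pattern over four little-endian limbs. A standalone sketch of the same logic, with a hypothetical q and helper name standing in for the generated q0..q3 constants and smallerThanModulus:

    package sketch

    import "math/bits"

    var q = [4]uint64{ /* q0..q3, least-significant limb first */ }

    // addMod computes z = x + y (mod q) for x, y < q. The generator only
    // enables the assembly path for "no carry" moduli (see F.NoCarry in
    // field_config.go later in this patch), i.e. q's top word leaves a
    // spare bit, so x+y cannot carry out of the fourth limb and the final
    // carry can be discarded.
    func addMod(z, x, y *[4]uint64) {
        var carry uint64
        z[0], carry = bits.Add64(x[0], y[0], 0)
        z[1], carry = bits.Add64(x[1], y[1], carry)
        z[2], carry = bits.Add64(x[2], y[2], carry)
        z[3], _ = bits.Add64(x[3], y[3], carry)
        if !lessThanQ(z) { // z >= q: one subtraction of q suffices
            var b uint64
            z[0], b = bits.Sub64(z[0], q[0], 0)
            z[1], b = bits.Sub64(z[1], q[1], b)
            z[2], b = bits.Sub64(z[2], q[2], b)
            z[3], _ = bits.Sub64(z[3], q[3], b)
        }
    }

    func lessThanQ(z *[4]uint64) bool {
        for i := 3; i >= 0; i-- {
            if z[i] != q[i] {
                return z[i] < q[i]
            }
        }
        return false // z == q
    }
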
func (z *Element) Double(x *Element) *Element { diff --git a/ecc/bls12-377/fr/element_ops_arm64.go b/ecc/bls12-377/fr/element_ops_arm64.go new file mode 100644 index 0000000000..1591b31d62 --- /dev/null +++ b/ecc/bls12-377/fr/element_ops_arm64.go @@ -0,0 +1,28 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s new file mode 100644 index 0000000000..5d79f98f87 --- /dev/null +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 17172654935612186478 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index dc38f08cd3..8cf8b8bd78 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { diff --git a/ecc/bls12-381/fr/element_ops_arm64.go b/ecc/bls12-381/fr/element_ops_arm64.go new file mode 100644 index 0000000000..1591b31d62 --- /dev/null +++ b/ecc/bls12-381/fr/element_ops_arm64.go @@ -0,0 +1,28 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
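A note on how these thin wrappers connect to the shared assembly: //go:noescape promises that the assembly does not retain the three *Element pointers, so escape analysis can keep operands on the caller's stack, and the matching TEXT ·add(SB), NOSPLIT, $0-24 in field/asm/element_4w_arm64.s below declares no local frame and 24 bytes of arguments, namely three 8-byte pointers at res+0(FP), x+8(FP) and y+16(FP). That layout is why the prologue can fetch both source pointers with a single LDP at x+8(FP). A hypothetical caller-side benchmark (illustrative only; SetUint64 and the chainable Add are the existing element API):

    package fr_test

    import (
        "testing"

        "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
    )

    // x, y and res stay on the benchmark's stack: //go:noescape tells the
    // compiler the assembly does not keep the pointers it receives.
    func BenchmarkElementAdd(b *testing.B) {
        var x, y, res fr.Element
        x.SetUint64(1)
        y.SetUint64(2)
        for i := 0; i < b.N; i++ {
            res.Add(&x, &y)
        }
    }
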
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s new file mode 100644 index 0000000000..5d79f98f87 --- /dev/null +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 17172654935612186478 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index abdb822acf..cae603ade5 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { diff --git a/ecc/bls24-315/fr/element_ops_arm64.go b/ecc/bls24-315/fr/element_ops_arm64.go new file mode 100644 index 0000000000..1591b31d62 --- /dev/null +++ b/ecc/bls24-315/fr/element_ops_arm64.go @@ -0,0 +1,28 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s new file mode 100644 index 0000000000..5d79f98f87 --- /dev/null +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
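The digest in these one-line stubs is doing real work: the Go build cache fingerprints the stub itself but knows nothing about the file pulled in via #include, so, as the generator's comment below puts it, the hash is embedded "to force the Go compiler to recompile" every per-curve package whenever the shared assembly changes. A sketch of how the digest is computed, mirroring hashAndInclude in field/generator/asm/arm64/build.go further down:

    package main

    import (
        "fmt"
        "hash/fnv"
        "os"
    )

    // Recomputes the digest embedded in the stubs above: a 64-bit FNV hash
    // of the shared assembly file's bytes.
    func main() {
        data, err := os.ReadFile("field/asm/element_4w_arm64.s")
        if err != nil {
            panic(err)
        }
        h := fnv.New64()
        _, _ = h.Write(data)
        fmt.Println(h.Sum64()) // e.g. 17172654935612186478 at this revision
    }
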
+// We include the hash to force the Go compiler to recompile: 17172654935612186478 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/bls24-317/fr/element.go b/ecc/bls24-317/fr/element.go index 3aefaebe62..790491630c 100644 --- a/ecc/bls24-317/fr/element.go +++ b/ecc/bls24-317/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { diff --git a/ecc/bls24-317/fr/element_ops_arm64.go b/ecc/bls24-317/fr/element_ops_arm64.go new file mode 100644 index 0000000000..1591b31d62 --- /dev/null +++ b/ecc/bls24-317/fr/element_ops_arm64.go @@ -0,0 +1,28 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s new file mode 100644 index 0000000000..5d79f98f87 --- /dev/null +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 17172654935612186478 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index 25fcdb67cc..87323af2bc 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { diff --git a/ecc/bn254/fp/element_ops_arm64.go b/ecc/bn254/fp/element_ops_arm64.go new file mode 100644 index 0000000000..c7a28f43c7 --- /dev/null +++ b/ecc/bn254/fp/element_ops_arm64.go @@ -0,0 +1,28 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s new file mode 100644 index 0000000000..5d79f98f87 --- /dev/null +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 17172654935612186478 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index 3650c954c5..3da98a1b57 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { diff --git a/ecc/bn254/fr/element_ops_arm64.go b/ecc/bn254/fr/element_ops_arm64.go new file mode 100644 index 0000000000..1591b31d62 --- /dev/null +++ b/ecc/bn254/fr/element_ops_arm64.go @@ -0,0 +1,28 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s new file mode 100644 index 0000000000..5d79f98f87 --- /dev/null +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 17172654935612186478 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/stark-curve/fp/element.go b/ecc/stark-curve/fp/element.go index 1c53dcb090..7a68362d40 100644 --- a/ecc/stark-curve/fp/element.go +++ b/ecc/stark-curve/fp/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { diff --git a/ecc/stark-curve/fp/element_ops_arm64.go b/ecc/stark-curve/fp/element_ops_arm64.go new file mode 100644 index 0000000000..c7a28f43c7 --- /dev/null +++ b/ecc/stark-curve/fp/element_ops_arm64.go @@ -0,0 +1,28 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s new file mode 100644 index 0000000000..5d79f98f87 --- /dev/null +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
+// We include the hash to force the Go compiler to recompile: 17172654935612186478 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/stark-curve/fr/element.go b/ecc/stark-curve/fr/element.go index 216e287ebb..601d7ba831 100644 --- a/ecc/stark-curve/fr/element.go +++ b/ecc/stark-curve/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { diff --git a/ecc/stark-curve/fr/element_ops_arm64.go b/ecc/stark-curve/fr/element_ops_arm64.go new file mode 100644 index 0000000000..1591b31d62 --- /dev/null +++ b/ecc/stark-curve/fr/element_ops_arm64.go @@ -0,0 +1,28 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s new file mode 100644 index 0000000000..5d79f98f87 --- /dev/null +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 17172654935612186478 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s new file mode 100644 index 0000000000..6edffb51cf --- /dev/null +++ b/field/asm/element_4w_arm64.s @@ -0,0 +1,131 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
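A reading guide for the routines that follow: ADDS opens a carry chain that each ADCS extends, and on arm64 a subtraction sets the carry flag when no borrow occurred, so after the trial subtraction of q, CSEL CS keeps the subtracted limbs exactly when z >= q. The neg routine additionally ORRs all input limbs into a scratch register ("has x been 0 so far?") so that the final CSEL EQ normalizes -0 to 0 rather than to q. A Go model of the branchless reduction, as a sketch (the assembly's CSEL performs this selection without the branch written here):

    package model

    import "math/bits"

    // reduceModel is the Go analogue of the "load modulus and subtract" +
    // "reduce if necessary" sequences below: compute t = z - q with borrow
    // b, then keep t exactly when the subtraction did not borrow (z >= q).
    func reduceModel(z *[4]uint64, q [4]uint64) {
        var t [4]uint64
        var b uint64
        t[0], b = bits.Sub64(z[0], q[0], 0)
        t[1], b = bits.Sub64(z[1], q[1], b)
        t[2], b = bits.Sub64(z[2], q[2], b)
        t[3], b = bits.Sub64(z[3], q[3], b)
        if b == 0 { // no borrow: carry set after SBCS, hence CSEL CS
            *z = t
        }
    }
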
+#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +// add(res, x, y *Element) +TEXT ·add(SB), NOSPLIT, $0-24 + LDP x+8(FP), (R4, R5) + + // load operands and add mod 2^r + LDP 0(R4), (R0, R6) + LDP 0(R5), (R1, R7) + ADDS R0, R1, R0 + ADCS R6, R7, R1 + LDP 16(R4), (R2, R6) + LDP 16(R5), (R3, R7) + ADCS R2, R3, R2 + ADCS R6, R7, R3 + + // load modulus and subtract + LDP ·qElement+0(SB), (R8, R9) + SUBS R8, R0, R8 + SBCS R9, R1, R9 + LDP ·qElement+16(SB), (R10, R11) + SBCS R10, R2, R10 + SBCS R11, R3, R11 + + // reduce if necessary + CSEL CS, R8, R0, R0 + CSEL CS, R9, R1, R1 + CSEL CS, R10, R2, R2 + CSEL CS, R11, R3, R3 + + // store + MOVD res+0(FP), R12 + +#define STOREVECTOR(in0, in1, in2, in3, in4) \ + STP (in1, in2), 0(in0) \ + STP (in3, in4), 128(in0) \ + + STOREVECTOR(R12, R0, R1, R2, R3) + RET + +// sub(res, x, y *Element) +TEXT ·sub(SB), NOSPLIT, $0-24 + LDP x+8(FP), (R4, R5) + + // load operands and subtract mod 2^r + LDP 0(R4), (R0, R6) + LDP 0(R5), (R1, R7) + SUBS R1, R0, R0 + SBCS R7, R6, R1 + LDP 16(R4), (R2, R6) + LDP 16(R5), (R3, R7) + SBCS R3, R2, R2 + SBCS R7, R6, R3 + + // load modulus and select + MOVD $0, R12 + LDP ·qElement+0(SB), (R8, R9) + CSEL CS, R12, R8, R8 + CSEL CS, R12, R9, R9 + LDP ·qElement+16(SB), (R10, R11) + CSEL CS, R12, R10, R10 + CSEL CS, R12, R11, R11 + + // augment (or not) + ADDS R0, R8, R0 + ADCS R1, R9, R1 + ADCS R2, R10, R2 + ADCS R3, R11, R3 + + // store + MOVD res+0(FP), R13 + RET + +// double(res, x *Element) +TEXT ·double(SB), NOSPLIT, $0-16 + LDP res+0(FP), (R5, R4) + + // load operands and add mod 2^r + LDP 0(R4), (R0, R1) + ADDS R0, R0, R0 + ADCS R1, R1, R1 + LDP 16(R4), (R2, R3) + ADCS R2, R2, R2 + ADCS R3, R3, R3 + + // load modulus and subtract + LDP ·qElement+0(SB), (R6, R7) + SUBS R6, R0, R6 + SBCS R7, R1, R7 + LDP ·qElement+16(SB), (R8, R9) + SBCS R8, R2, R8 + SBCS R9, R3, R9 + + // reduce if necessary + CSEL CS, R6, R0, R0 + CSEL CS, R7, R1, R1 + CSEL CS, R8, R2, R2 + CSEL CS, R9, R3, R3 + + // store + RET + +// neg(res, x *Element) +TEXT ·neg(SB), NOSPLIT, $0-16 + LDP res+0(FP), (R5, R4) + + // load operands and subtract + MOVD $0, R8 + LDP 0(R4), (R0, R1) + LDP ·qElement+0(SB), (R6, R7) + ORR R0, R8, R8 // has x been 0 so far? + ORR R1, R8, R8 + SUBS R0, R6, R0 + SBCS R1, R7, R1 + LDP 16(R4), (R2, R3) + LDP ·qElement+16(SB), (R6, R7) + ORR R2, R8, R8 // has x been 0 so far? 
+ ORR R3, R8, R8 + SBCS R2, R6, R2 + SBCS R3, R7, R3 + TST $0xffffffffffffffff, R8 + CSEL EQ, R8, R0, R0 + CSEL EQ, R8, R1, R1 + CSEL EQ, R8, R2, R2 + CSEL EQ, R8, R3, R3 + + // store + RET diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 9f2e44bd1a..40edcfc366 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -137,13 +137,6 @@ startDefine: return toReturn } -func max(a, b int) int { - if a > b { - return a - } - return b -} - func (f *FFAmd64) AssertCleanStack(reservedStackSize, minStackSize int) { if f.nbElementsOnStack != 0 { panic("missing f.Push stack elements") diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go new file mode 100644 index 0000000000..2b0a628a8d --- /dev/null +++ b/field/generator/asm/arm64/build.go @@ -0,0 +1,287 @@ +package arm64 + +import ( + "fmt" + "hash/fnv" + "io" + "os" + "path/filepath" + "strings" + + "github.com/consensys/bavard/arm64" + "github.com/consensys/gnark-crypto/field/generator/config" +) + +const ( + ElementASMFileName = "element_%dw_arm64.s" +) + +type defineFn func(args ...arm64.Register) + +func NewFFArm64(w io.Writer, nbWords int) *FFArm64 { + F := &FFArm64{ + arm64.NewArm64(w), + 0, + 0, + nbWords, + nbWords - 1, + make([]int, nbWords), + make([]int, nbWords-1), + make(map[string]defineFn), + } + + // indexes (template helpers) + for i := 0; i < F.NbWords; i++ { + F.NbWordsIndexesFull[i] = i + if i > 0 { + F.NbWordsIndexesNoZero[i-1] = i + } + } + + return F +} + +type FFArm64 struct { + *arm64.Arm64 + nbElementsOnStack int + maxOnStack int + NbWords int + NbWordsLastIndex int + NbWordsIndexesFull []int + NbWordsIndexesNoZero []int + mDefines map[string]defineFn +} + +func GenerateFieldWrapper(w io.Writer, F *config.FieldConfig, asmDirBuildPath, asmDirIncludePath string) error { + // for each field we generate the defines for the modulus and the montgomery constant + f := NewFFArm64(w, F.NbWords) + + // we add the defines first, then the common asm, then the global variable section + // to enable correct compilations with #include in order. + f.WriteLn("") + + hashAndInclude := func(fileName string) error { + // we hash the file content and include the hash in comment of the generated file + // to force the Go compiler to recompile the file if the content has changed + fData, err := os.ReadFile(filepath.Join(asmDirBuildPath, fileName)) + if err != nil { + return err + } + // hash the file using FNV + hasher := fnv.New64() + hasher.Write(fData) + hash := hasher.Sum64() + + f.WriteLn("// Code generated by gnark-crypto/generator. DO NOT EDIT.") + f.WriteLn(fmt.Sprintf("// We include the hash to force the Go compiler to recompile: %d", hash)) + includePath := filepath.Join(asmDirIncludePath, fileName) + // on windows, we replace the "\" by "/" + if filepath.Separator == '\\' { + includePath = strings.ReplaceAll(includePath, "\\", "/") + } + f.WriteLn(fmt.Sprintf("#include \"%s\"\n", includePath)) + + return nil + } + + toInclude := fmt.Sprintf(ElementASMFileName, F.NbWords) + if err := hashAndInclude(toInclude); err != nil { + return err + } + + return nil +} + +// GenerateCommonASM generates assembly code for the base field provided to goff +// see internal/templates/ops* +func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { + f := NewFFArm64(w, nbWords) + f.Comment("Code generated by gnark-crypto/generator. 
DO NOT EDIT.") + + f.WriteLn("#include \"textflag.h\"") + f.WriteLn("#include \"funcdata.h\"") + f.WriteLn("#include \"go_asm.h\"") + f.WriteLn("") + + f.generateAdd() + f.generateSub() + f.generateDouble() + f.generateNeg() + + return nil +} + +// // Generate generates assembly code for the base field provided to goff +// // see internal/templates/ops* +// func Generate(w io.Writer, F *field.Field) error { +// f := NewFFArm64(w, F) +// f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) + +// f.WriteLn("#include \"textflag.h\"") +// f.WriteLn("#include \"funcdata.h\"\n") + +// f.generateStoreVector() + +// // add +// //TODO: It requires field size < 960 +// f.generateAdd() + +// // sub +// f.generateSub() + +// // double +// f.generateDouble() + +// // neg +// f.generateNeg() +// /* +// // reduce +// f.generateReduce() + +// // mul by constants +// f.generateMulBy3() +// f.generateMulBy5() +// f.generateMulBy13() + +// // fft butterflies +// f.generateButterfly()*/ + +// return nil +// } + +func (f *FFArm64) DefineFn(name string) (fn defineFn, err error) { + fn, ok := f.mDefines[name] + if !ok { + return nil, fmt.Errorf("function %s not defined", name) + } + return fn, nil +} + +func (f *FFArm64) Define(name string, nbInputs int, fn defineFn) defineFn { + + inputs := make([]string, nbInputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = fmt.Sprintf("in%d", i) + } + name = strings.ToUpper(name) + + for _, ok := f.mDefines[name]; ok; { + // name already exist, for code generation purpose we add a suffix + // should happen only with e2 deprecated functions + fmt.Println("WARNING: function name already defined, adding suffix") + i := 0 + for { + newName := fmt.Sprintf("%s_%d", name, i) + if _, ok := f.mDefines[newName]; !ok { + name = newName + goto startDefine + } + i++ + } + } +startDefine: + + f.StartDefine() + f.WriteLn("#define " + name + "(" + strings.Join(inputs, ", ") + ")") + inputsRegisters := make([]arm64.Register, nbInputs) + for i := 0; i < nbInputs; i++ { + inputsRegisters[i] = arm64.Register(inputs[i]) + } + fn(inputsRegisters...) 
+ f.EndDefine() + f.WriteLn("") + + toReturn := func(args ...arm64.Register) { + if len(args) != nbInputs { + panic("invalid number of arguments") + } + inputsStr := make([]string, len(args)) + for i := 0; i < len(args); i++ { + inputsStr[i] = string(args[i]) + } + f.WriteLn(name + "(" + strings.Join(inputsStr, ", ") + ")") + } + + f.mDefines[name] = toReturn + + return toReturn +} + +func (f *FFArm64) AssertCleanStack(reservedStackSize, minStackSize int) { + if f.nbElementsOnStack != 0 { + panic("missing f.Push stack elements") + } + if reservedStackSize < minStackSize { + panic("invalid minStackSize or reservedStackSize") + } + usedStackSize := f.maxOnStack * 8 + if usedStackSize > reservedStackSize { + panic("using more stack size than reserved") + } else if max(usedStackSize, minStackSize) < reservedStackSize { + // this panic is for dev purposes as this may be by design for aligment + panic("reserved more stack size than needed") + } + + f.maxOnStack = 0 +} + +func (f *FFArm64) Push(registers *arm64.Registers, rIn ...arm64.Register) { + for _, r := range rIn { + if strings.HasPrefix(string(r), "s") { + // it's on the stack, decrease the offset + f.nbElementsOnStack-- + continue + } + registers.Push(r) + } +} + +func (f *FFArm64) Pop(registers *arm64.Registers, forceStack ...bool) arm64.Register { + if registers.Available() >= 1 && !(len(forceStack) > 0 && forceStack[0]) { + return registers.Pop() + } + r := arm64.Register(fmt.Sprintf("s%d-%d(SP)", f.nbElementsOnStack, 8+f.nbElementsOnStack*8)) + f.nbElementsOnStack++ + if f.nbElementsOnStack > f.maxOnStack { + f.maxOnStack = f.nbElementsOnStack + } + return r +} + +func (f *FFArm64) PopN(registers *arm64.Registers, forceStack ...bool) []arm64.Register { + if len(forceStack) > 0 && forceStack[0] { + nbStack := f.NbWords + var u []arm64.Register + + for i := f.nbElementsOnStack; i < nbStack+f.nbElementsOnStack; i++ { + u = append(u, arm64.Register(fmt.Sprintf("s%d-%d(SP)", i, 8+i*8))) + } + f.nbElementsOnStack += nbStack + if f.nbElementsOnStack > f.maxOnStack { + f.maxOnStack = f.nbElementsOnStack + } + return u + } + if registers.Available() >= f.NbWords { + return registers.PopN(f.NbWords) + } + nbStack := f.NbWords - registers.Available() + u := registers.PopN(registers.Available()) + + for i := f.nbElementsOnStack; i < nbStack+f.nbElementsOnStack; i++ { + u = append(u, arm64.Register(fmt.Sprintf("s%d-%d(SP)", i, 8+i*8))) + } + f.nbElementsOnStack += nbStack + if f.nbElementsOnStack > f.maxOnStack { + f.maxOnStack = f.nbElementsOnStack + } + return u +} + +func (f *FFArm64) qAt(index int) string { + return fmt.Sprintf("·qElement+%d(SB)", index*8) +} + +func (f *FFArm64) qInv0() string { + return "$const_qInvNeg" +} diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go new file mode 100644 index 0000000000..66bfe63979 --- /dev/null +++ b/field/generator/asm/arm64/element_ops.go @@ -0,0 +1,292 @@ +// Copyright 2022 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package arm64 + +import ( + "github.com/consensys/bavard/arm64" +) + +func (f *FFArm64) generateAdd() { + if f.NbWords%2 != 0 { + panic("NbWords must be even") + } + f.Comment("add(res, x, y *Element)") + //stackSize := f.StackSize(f.NbWords*2, 0, 0) + registers := f.FnHeader("add", 0, 24) + defer f.AssertCleanStack(0, 0) + + // registers + z := registers.PopN(f.NbWords) + xPtr := registers.Pop() + yPtr := registers.Pop() + ops := registers.PopN(2) + + f.LDP("x+8(FP)", xPtr, yPtr) + f.Comment("load operands and add mod 2^r") + + op0 := f.ADDS + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(f.RegisterOffset(xPtr, 8*i), z[i], ops[0]) + f.LDP(f.RegisterOffset(yPtr, 8*i), z[i+1], ops[1]) + + op0(z[i], z[i+1], z[i]) + op0 = f.ADCS + + f.ADCS(ops[0], ops[1], z[i+1]) + } + + if f.NbWords%2 == 1 { + i := f.NbWords - 1 + f.MOVD(f.RegisterOffset(xPtr, 8*i), z[i], "can't import these in pairs") + f.MOVD(f.RegisterOffset(yPtr, 8*i), ops[0]) + op0(z[i], ops[0], z[i]) + } + registers.Push(xPtr, yPtr) + registers.Push(ops...) + + t := registers.PopN(f.NbWords) + f.reduce(z, t) + registers.Push(t...) + + f.Comment("store") + zPtr := registers.Pop() + f.MOVD("res+0(FP)", zPtr) + + storeVector := f.Define("storeVector", f.NbWords+1, func(args ...arm64.Register) { + res0 := args[0] + for i := 1; i < len(args); i += 2 { + f.STP(args[i], args[i+1], res0.At(8*(i-1))) + } + }) + // for i := 0; i < f.NbWords; i++ { + // f.MOVD(z[i], f.RegisterOffset(zPtr, 8*i)) + // } + _z := append([]arm64.Register{zPtr}, z...) + storeVector(_z...) + // f.storeVector(z, zPtr) + + f.RET() + +} + +func (f *FFArm64) generateDouble() { + f.Comment("double(res, x *Element)") + registers := f.FnHeader("double", 0, 16) + defer f.AssertCleanStack(0, 0) + + // registers + z := registers.PopN(f.NbWords) + xPtr := registers.Pop() + zPtr := registers.Pop() + //ops := registers.PopN(2) + + f.LDP("res+0(FP)", zPtr, xPtr) + f.Comment("load operands and add mod 2^r") + + op0 := f.ADDS + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(f.RegisterOffset(xPtr, 8*i), z[i], z[i+1]) + + op0(z[i], z[i], z[i]) + op0 = f.ADCS + + f.ADCS(z[i+1], z[i+1], z[i+1]) + } + + if f.NbWords%2 == 1 { + i := f.NbWords - 1 + f.MOVD(f.RegisterOffset(xPtr, 8*i), z[i]) + op0(z[i], z[i], z[i]) + } + registers.Push(xPtr) + + t := registers.PopN(f.NbWords) + f.reduce(z, t) + registers.Push(t...) + + f.Comment("store") + f.storeVector(z, zPtr) + + f.RET() + +} + +// generateSub NO LONGER uses one more register than generateAdd, but that's okay since we have 29 registers available. +func (f *FFArm64) generateSub() { + f.Comment("sub(res, x, y *Element)") + + registers := f.FnHeader("sub", 0, 24) + defer f.AssertCleanStack(0, 0) + + // registers + z := registers.PopN(f.NbWords) + xPtr := registers.Pop() + yPtr := registers.Pop() + ops := registers.PopN(2) + + f.LDP("x+8(FP)", xPtr, yPtr) + f.Comment("load operands and subtract mod 2^r") + + op0 := f.SUBS + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(f.RegisterOffset(xPtr, 8*i), z[i], ops[0]) + f.LDP(f.RegisterOffset(yPtr, 8*i), z[i+1], ops[1]) + + op0(z[i+1], z[i], z[i]) + op0 = f.SBCS + + f.SBCS(ops[1], ops[0], z[i+1]) + } + + if f.NbWords%2 == 1 { + i := f.NbWords - 1 + f.MOVD(f.RegisterOffset(xPtr, 8*i), z[i], "can't import these in pairs") + f.MOVD(f.RegisterOffset(yPtr, 8*i), ops[0]) + op0(ops[0], z[i], z[i]) + } + registers.Push(xPtr, yPtr) + registers.Push(ops...) 
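+	// Branchless correction follows: q is loaded into t, then zeroed via
+	// CSEL when the subtraction above did not borrow (carry set on arm64),
+	// and finally added back unconditionally; z gains q exactly when the
+	// raw subtraction went negative.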
+ + f.Comment("load modulus and select") + + t := registers.PopN(f.NbWords) + zero := registers.Pop() + f.MOVD(0, zero) + + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(f.qAt(i), t[i], t[i+1]) + + f.CSEL("CS", zero, t[i], t[i]) + f.CSEL("CS", zero, t[i+1], t[i+1]) + } + + if f.NbWords%2 == 1 { + i := f.NbWords - 1 + f.MOVD(f.qAt(i), t[i]) + + f.CSEL("CS", zero, t[i], t[i]) + } + + registers.Push(zero) + + f.Comment("augment (or not)") + + op0 = f.ADDS + for i := 0; i < f.NbWords; i++ { + op0(z[i], t[i], z[i]) + op0 = f.ADCS + } + + registers.Push(t...) + + f.Comment("store") + zPtr := registers.Pop() + f.MOVD("res+0(FP)", zPtr) + f.storeVector(z, zPtr) + + f.RET() + +} + +func (f *FFArm64) generateNeg() { + f.Comment("neg(res, x *Element)") + registers := f.FnHeader("neg", 0, 16) + defer f.AssertCleanStack(0, 0) + + // registers + z := registers.PopN(f.NbWords) + xPtr := registers.Pop() + zPtr := registers.Pop() + ops := registers.PopN(2) + xNotZero := registers.Pop() + + f.LDP("res+0(FP)", zPtr, xPtr) + f.Comment("load operands and subtract") + + f.MOVD(0, xNotZero) + op0 := f.SUBS + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(f.RegisterOffset(xPtr, 8*i), z[i], z[i+1]) + f.LDP(f.qAt(i), ops[0], ops[1]) + + f.ORR(z[i], xNotZero, xNotZero, "has x been 0 so far?") + f.ORR(z[i+1], xNotZero, xNotZero) + + op0(z[i], ops[0], z[i]) + op0 = f.SBCS + + f.SBCS(z[i+1], ops[1], z[i+1]) + } + + if f.NbWords%2 == 1 { + i := f.NbWords - 1 + f.MOVD(f.RegisterOffset(xPtr, 8*i), z[i], "can't import these in pairs") + f.MOVD(f.qAt(i), ops[0]) + + f.ORR(z[i], xNotZero, xNotZero) + + op0(z[i], ops[0], z[i]) + } + + registers.Push(xPtr) + registers.Push(ops...) + + f.TST(-1, xNotZero) + for i := 0; i < f.NbWords; i++ { + f.CSEL("EQ", xNotZero, z[i], z[i]) + } + + f.Comment("store") + f.storeVector(z, zPtr) + + f.RET() + +} + +func (f *FFArm64) reduce(z, t []arm64.Register) { + + if len(z) != f.NbWords || len(t) != f.NbWords { + panic("need 2*nbWords registers") + } + + f.Comment("load modulus and subtract") + + op0 := f.SUBS + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(f.qAt(i), t[i], t[i+1]) + + op0(t[i], z[i], t[i]) + op0 = f.SBCS + + f.SBCS(t[i+1], z[i+1], t[i+1]) + } + + if f.NbWords%2 == 1 { + i := f.NbWords - 1 + f.MOVD(f.qAt(i), t[i]) + + op0(t[i], z[i], t[i]) + } + + f.Comment("reduce if necessary") + + for i := 0; i < f.NbWords; i++ { + f.CSEL("CS", t[i], z[i], z[i]) + } +} + +func (f *FFArm64) storeVector(vector interface{}, baseAddress arm64.Register) { + // f.callTemplate("storeVector", toInterfaceSlice(baseAddress, vector)...) 
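+	// Left unimplemented at this checkpoint: this is why sub, double and
+	// neg in field/asm/element_4w_arm64.s above end with a "// store"
+	// comment but emit no actual stores; only add, which goes through the
+	// Define-based STOREVECTOR macro, writes its result back.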
+} diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 6d24aab743..9b0d8d47d0 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -54,6 +54,7 @@ type FieldConfig struct { Mu uint64 // mu = 2^288 / q for 4.5 word barrett reduction ASM bool ASMVector bool + ASMArm bool RSquare []uint64 One, Thirteen []uint64 LegendreExponent string // big.Int to base16 string @@ -263,6 +264,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // asm code generation for moduli with more than 6 words can be optimized further F.ASM = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 F.ASMVector = F.ASM && F.NbWords == 4 && F.NbBits > 225 + F.ASMArm = F.ASMVector // setting Mu 2^288 / q if F.NbWords == 4 { diff --git a/field/generator/generator.go b/field/generator/generator.go index 0150897677..e9c86dfdd7 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -11,6 +11,7 @@ import ( "github.com/consensys/bavard" "github.com/consensys/gnark-crypto/field/generator/asm/amd64" + "github.com/consensys/gnark-crypto/field/generator/asm/arm64" "github.com/consensys/gnark-crypto/field/generator/config" "github.com/consensys/gnark-crypto/field/generator/internal/addchain" "github.com/consensys/gnark-crypto/field/generator/internal/templates/element" @@ -157,6 +158,36 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude } + if F.ASMArm { + // generate ops.s + { + pathSrc := filepath.Join(outputDir, eName+"_ops_arm64.s") + fmt.Println("generating", pathSrc) + f, err := os.Create(pathSrc) + if err != nil { + return err + } + + _, _ = io.WriteString(f, "// +build !purego\n") + + if err := arm64.GenerateFieldWrapper(f, F, asmDirBuildPath, asmDirIncludePath); err != nil { + _ = f.Close() + return err + } + _ = f.Close() + + // run asmfmt + // run go fmt on whole directory + cmd := exec.Command("asmfmt", "-w", pathSrc) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return err + } + } + + } + if F.ASM { // generate ops_amd64.go src := []string{ @@ -174,6 +205,23 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude } } + if F.ASMArm { + // generate ops_arm64.go + src := []string{ + element.MulDoc, + element.OpsARM64, + } + pathSrc := filepath.Join(outputDir, eName+"_ops_arm64.go") + bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) + copy(bavardOptsCpy, bavardOpts) + if F.ASM { + bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!purego")) + } + if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil { + return err + } + } + { // generate ops.go src := []string{ @@ -279,29 +327,56 @@ func shorten(input string) string { return input } -func GenerateCommonASM(nbWords int, asmDir string, hasVector bool) error { +func GenerateCommonASM(nbWords int, asmDir string, hasVector bool, hasArm bool) error { os.MkdirAll(asmDir, 0755) - pathSrc := filepath.Join(asmDir, fmt.Sprintf(amd64.ElementASMFileName, nbWords)) + { + pathSrc := filepath.Join(asmDir, fmt.Sprintf(amd64.ElementASMFileName, nbWords)) - fmt.Println("generating", pathSrc) - f, err := os.Create(pathSrc) - if err != nil { - return err - } + fmt.Println("generating", pathSrc) + f, err := os.Create(pathSrc) + if err != nil { + return err + } - if err := amd64.GenerateCommonASM(f, nbWords, hasVector); err != nil { + if err := amd64.GenerateCommonASM(f, nbWords, 
hasVector); err != nil { + _ = f.Close() + return err + } _ = f.Close() - return err + + // run asmfmt + // run go fmt on whole directory + cmd := exec.Command("asmfmt", "-w", pathSrc) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return err + } } - _ = f.Close() - // run asmfmt - // run go fmt on whole directory - cmd := exec.Command("asmfmt", "-w", pathSrc) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return err + if hasArm { + pathSrc := filepath.Join(asmDir, fmt.Sprintf(arm64.ElementASMFileName, nbWords)) + + fmt.Println("generating", pathSrc) + f, err := os.Create(pathSrc) + if err != nil { + return err + } + + if err := arm64.GenerateCommonASM(f, nbWords, hasVector); err != nil { + _ = f.Close() + return err + } + _ = f.Close() + + // run asmfmt + // run go fmt on whole directory + cmd := exec.Command("asmfmt", "-w", pathSrc) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return err + } } return nil diff --git a/field/generator/generator_test.go b/field/generator/generator_test.go index 2c4b055f65..3da09033a3 100644 --- a/field/generator/generator_test.go +++ b/field/generator/generator_test.go @@ -79,10 +79,10 @@ func TestIntegration(t *testing.T) { moduli["e_nocarry_edge_0127"] = "170141183460469231731687303715884105727" moduli["e_nocarry_edge_1279"] = "10407932194664399081925240327364085538615262247266704805319112350403608059673360298012239441732324184842421613954281007791383566248323464908139906605677320762924129509389220345773183349661583550472959420547689811211693677147548478866962501384438260291732348885311160828538416585028255604666224831890918801847068222203140521026698435488732958028878050869736186900714720710555703168729087" - assert.NoError(GenerateCommonASM(2, asmDir, false)) - assert.NoError(GenerateCommonASM(3, asmDir, false)) - assert.NoError(GenerateCommonASM(7, asmDir, false)) - assert.NoError(GenerateCommonASM(8, asmDir, false)) + assert.NoError(GenerateCommonASM(2, asmDir, false, false)) + assert.NoError(GenerateCommonASM(3, asmDir, false, false)) + assert.NoError(GenerateCommonASM(7, asmDir, false, false)) + assert.NoError(GenerateCommonASM(8, asmDir, false, false)) for elementName, modulus := range moduli { var fIntegration *field.FieldConfig diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index 7dc8bffc4e..f54a5181d6 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -396,6 +396,7 @@ func (z *{{.ElementName}}) fromMont() *{{.ElementName}} { return z } +{{- if not .ASMArm}} // Add z = x + y (mod q) func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { {{ $hasCarry := or (not $.NoCarry) (gt $.NbWords 1)}} @@ -430,6 +431,7 @@ func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { {{- end}} return z } +{{- end}} // Double z = x + x (mod q), aka Lsh 1 func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} { diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 1d1408e264..e07c71e628 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -195,4 +195,16 @@ func (z *{{.ElementName}}) Square(x *{{.ElementName}}) *{{.ElementName}} { +` + +const OpsARM64 = ` +{{if .ASMArm}} +//go:noescape +func 
add(res,x,y *{{.ElementName}}) + +func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { + add(z,x,y) + return z +} +{{end}} ` diff --git a/go.mod b/go.mod index b5486971f2..f7ed17e228 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.22 + github.com/consensys/bavard v0.0.0 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 @@ -15,6 +15,8 @@ require ( gopkg.in/yaml.v2 v2.4.0 ) +replace github.com/consensys/bavard => ../bavard + require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 26af63d69f..18a8ee2e2b 100644 --- a/go.sum +++ b/go.sum @@ -55,8 +55,6 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/consensys/bavard v0.1.22 h1:Uw2CGvbXSZWhqK59X0VG/zOjpTFuOMcPLStrp1ihI0A= -github.com/consensys/bavard v0.1.22/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= diff --git a/internal/generator/main.go b/internal/generator/main.go index a9fe8390ca..2f84a433e7 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -54,6 +54,7 @@ func main() { // generate common assembly files depending on field number of words mCommon := make(map[int]bool) mVec := make(map[int]bool) + mArm := make(map[int]bool) for i, conf := range config.Curves { var err error @@ -73,12 +74,18 @@ func main() { if conf.Fp.ASMVector { mVec[conf.Fp.NbWords] = true } + if conf.Fr.ASMArm { + mArm[conf.Fr.NbWords] = true + } + if conf.Fp.ASMArm { + mArm[conf.Fp.NbWords] = true + } config.Curves[i] = conf } for nbWords := range mCommon { - assertNoError(generator.GenerateCommonASM(nbWords, asmDirBuildPath, mVec[nbWords])) + assertNoError(generator.GenerateCommonASM(nbWords, asmDirBuildPath, mVec[nbWords], mArm[nbWords])) } var wg sync.WaitGroup From c60ce4c7016988f44d2241b4669392d85dd1e783 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 15 Oct 2024 17:05:08 -0500 Subject: [PATCH 02/74] checkpoint --- ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element_ops_arm64.s | 2 +- field/asm/element_4w_arm64.s | 4 +-- field/generator/asm/arm64/build.go | 4 +++ field/generator/asm/arm64/element_ops.go | 45 +++++------------------- 11 files changed, 22 insertions(+), 47 deletions(-) diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index 5d79f98f87..d74d9753b8 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ 
b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 17172654935612186478 +// We include the hash to force the Go compiler to recompile: 9776471158029765872 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s index 5d79f98f87..d74d9753b8 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 17172654935612186478 +// We include the hash to force the Go compiler to recompile: 9776471158029765872 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s index 5d79f98f87..d74d9753b8 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 17172654935612186478 +// We include the hash to force the Go compiler to recompile: 9776471158029765872 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s index 5d79f98f87..d74d9753b8 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 17172654935612186478 +// We include the hash to force the Go compiler to recompile: 9776471158029765872 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s index 5d79f98f87..d74d9753b8 100644 --- a/ecc/bn254/fp/element_ops_arm64.s +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 17172654935612186478 +// We include the hash to force the Go compiler to recompile: 9776471158029765872 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s index 5d79f98f87..d74d9753b8 100644 --- a/ecc/bn254/fr/element_ops_arm64.s +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 17172654935612186478 +// We include the hash to force the Go compiler to recompile: 9776471158029765872 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index 5d79f98f87..d74d9753b8 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 17172654935612186478 +// We include the hash to force the Go compiler to recompile: 9776471158029765872 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index 5d79f98f87..d74d9753b8 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 17172654935612186478 +// We include the hash to force the Go compiler to recompile: 9776471158029765872 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index 6edffb51cf..d4e2c9ccd2 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -35,8 +35,8 @@ TEXT ·add(SB), NOSPLIT, $0-24 MOVD res+0(FP), R12 #define STOREVECTOR(in0, in1, in2, in3, in4) \ - STP (in1, in2), 0(in0) \ - STP (in3, in4), 128(in0) \ + STP (in1, in2), 0(in0) \ + STP (in3, in4), 16(in0) \ STOREVECTOR(R12, R0, R1, R2, R3) RET diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 2b0a628a8d..d88428bd2a 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -103,6 +103,10 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { f.WriteLn("#include \"go_asm.h\"") f.WriteLn("") + if f.NbWords%2 != 0 { + panic("NbWords must be even") + } + f.generateAdd() f.generateSub() f.generateDouble() diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index 66bfe63979..2b55450a3b 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -19,9 +19,6 @@ import ( ) func (f *FFArm64) generateAdd() { - if f.NbWords%2 != 0 { - panic("NbWords must be even") - } f.Comment("add(res, x, y *Element)") //stackSize := f.StackSize(f.NbWords*2, 0, 0) registers := f.FnHeader("add", 0, 24) @@ -38,8 +35,8 @@ func (f *FFArm64) generateAdd() { op0 := f.ADDS for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.RegisterOffset(xPtr, 8*i), z[i], ops[0]) - f.LDP(f.RegisterOffset(yPtr, 8*i), z[i+1], ops[1]) + f.LDP(xPtr.At(i), z[i], ops[0]) + f.LDP(yPtr.At(i), z[i+1], ops[1]) op0(z[i], z[i+1], z[i]) op0 = f.ADCS @@ -47,12 +44,6 @@ func (f *FFArm64) generateAdd() { f.ADCS(ops[0], ops[1], z[i+1]) } - if f.NbWords%2 == 1 { - i := f.NbWords - 1 - f.MOVD(f.RegisterOffset(xPtr, 8*i), z[i], "can't import these in pairs") - f.MOVD(f.RegisterOffset(yPtr, 8*i), ops[0]) - op0(z[i], ops[0], z[i]) - } registers.Push(xPtr, yPtr) registers.Push(ops...) @@ -67,15 +58,11 @@ func (f *FFArm64) generateAdd() { storeVector := f.Define("storeVector", f.NbWords+1, func(args ...arm64.Register) { res0 := args[0] for i := 1; i < len(args); i += 2 { - f.STP(args[i], args[i+1], res0.At(8*(i-1))) + f.STP(args[i], args[i+1], res0.At(i-1)) } }) - // for i := 0; i < f.NbWords; i++ { - // f.MOVD(z[i], f.RegisterOffset(zPtr, 8*i)) - // } _z := append([]arm64.Register{zPtr}, z...) storeVector(_z...) 
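The STOREVECTOR change above (STP ... 128(in0) becoming 16(in0)) and the At change here are the same fix: bavard's arm64 Register.At appears to take a word index and scale it by 8 when rendering the offset, so the original At(8*(i-1)) double-scaled it. At(16) rendered as 128(in0), placing the second STP 112 bytes past the element instead of at offset 16. A sketch of the convention as inferred from this diff (register names illustrative):

    // Register.At(n) renders as "8n(reg)": a word index, not a byte offset.
    f.STP(z[0], z[1], zPtr.At(0)) // emits: STP (R0, R1), 0(R12)
    f.STP(z[2], z[3], zPtr.At(2)) // emits: STP (R2, R3), 16(R12)
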
- // f.storeVector(z, zPtr) f.RET() @@ -97,7 +84,7 @@ func (f *FFArm64) generateDouble() { op0 := f.ADDS for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.RegisterOffset(xPtr, 8*i), z[i], z[i+1]) + f.LDP(xPtr.At(i), z[i], z[i+1]) op0(z[i], z[i], z[i]) op0 = f.ADCS @@ -107,7 +94,7 @@ func (f *FFArm64) generateDouble() { if f.NbWords%2 == 1 { i := f.NbWords - 1 - f.MOVD(f.RegisterOffset(xPtr, 8*i), z[i]) + f.MOVD(xPtr.At(i), z[i]) op0(z[i], z[i], z[i]) } registers.Push(xPtr) @@ -141,8 +128,8 @@ func (f *FFArm64) generateSub() { op0 := f.SUBS for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.RegisterOffset(xPtr, 8*i), z[i], ops[0]) - f.LDP(f.RegisterOffset(yPtr, 8*i), z[i+1], ops[1]) + f.LDP(xPtr.At(i), z[i], ops[0]) + f.LDP(yPtr.At(i), z[i+1], ops[1]) op0(z[i+1], z[i], z[i]) op0 = f.SBCS @@ -150,12 +137,6 @@ func (f *FFArm64) generateSub() { f.SBCS(ops[1], ops[0], z[i+1]) } - if f.NbWords%2 == 1 { - i := f.NbWords - 1 - f.MOVD(f.RegisterOffset(xPtr, 8*i), z[i], "can't import these in pairs") - f.MOVD(f.RegisterOffset(yPtr, 8*i), ops[0]) - op0(ops[0], z[i], z[i]) - } registers.Push(xPtr, yPtr) registers.Push(ops...) @@ -218,7 +199,7 @@ func (f *FFArm64) generateNeg() { f.MOVD(0, xNotZero) op0 := f.SUBS for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.RegisterOffset(xPtr, 8*i), z[i], z[i+1]) + f.LDP(xPtr.At(i), z[i], z[i+1]) f.LDP(f.qAt(i), ops[0], ops[1]) f.ORR(z[i], xNotZero, xNotZero, "has x been 0 so far?") @@ -230,16 +211,6 @@ func (f *FFArm64) generateNeg() { f.SBCS(z[i+1], ops[1], z[i+1]) } - if f.NbWords%2 == 1 { - i := f.NbWords - 1 - f.MOVD(f.RegisterOffset(xPtr, 8*i), z[i], "can't import these in pairs") - f.MOVD(f.qAt(i), ops[0]) - - f.ORR(z[i], xNotZero, xNotZero) - - op0(z[i], ops[0], z[i]) - } - registers.Push(xPtr) registers.Push(ops...) From 5832e15126575f6ad8987d65c4d68778f3a36ff0 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 15 Oct 2024 17:11:40 -0500 Subject: [PATCH 03/74] build: update bavard --- go.mod | 4 +--- go.sum | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index f7ed17e228..9006ce5c13 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.0.0 + github.com/consensys/bavard v0.1.23-0.20241015221109-a56d5bf777eb github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 @@ -15,8 +15,6 @@ require ( gopkg.in/yaml.v2 v2.4.0 ) -replace github.com/consensys/bavard => ../bavard - require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 18a8ee2e2b..1164581955 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/consensys/bavard v0.1.23-0.20241015221109-a56d5bf777eb h1:yPPmCz5FvvKMAKz/O7t5qJNZcEA0q6ermJzoL2D0oQU= +github.com/consensys/bavard v0.1.23-0.20241015221109-a56d5bf777eb/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= 
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= From ec77f9cb18822c4fa4a4ca4de9512a56e0e05747 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 15 Oct 2024 22:34:50 +0000 Subject: [PATCH 04/74] checkpoint --- ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element_ops_arm64.s | 2 +- field/asm/element_4w_arm64.s | 139 ++++------------------- field/generator/asm/arm64/build.go | 10 +- field/generator/asm/arm64/element_ops.go | 54 +++------ 11 files changed, 57 insertions(+), 162 deletions(-) diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index d74d9753b8..6683e0c8bd 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9776471158029765872 +// We include the hash to force the Go compiler to recompile: 11169582256709554223 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s index d74d9753b8..6683e0c8bd 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9776471158029765872 +// We include the hash to force the Go compiler to recompile: 11169582256709554223 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s index d74d9753b8..6683e0c8bd 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9776471158029765872 +// We include the hash to force the Go compiler to recompile: 11169582256709554223 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s index d74d9753b8..6683e0c8bd 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9776471158029765872 +// We include the hash to force the Go compiler to recompile: 11169582256709554223 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s index d74d9753b8..6683e0c8bd 100644 --- a/ecc/bn254/fp/element_ops_arm64.s +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
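The eight per-curve .s files touched above are one-line wrappers around the shared field/asm/element_4w_arm64.s; the only thing that changes at each checkpoint is the hash in their header. The comment spells out the reason: the wrapper's bytes must change whenever the shared file does, or the Go build cache can keep serving stale objects for packages whose assembly arrives through #include. A minimal sketch of how a generator can produce such a wrapper — the hash function and helper name are assumptions for illustration, not the project's actual code:

package gen

import (
	"fmt"
	"hash/fnv"
	"io"
)

// writeWrapper emits a per-curve stub whose content embeds a hash of the
// shared assembly, so any edit to the shared file rewrites every stub and
// invalidates the build cache for the packages that include it.
func writeWrapper(w io.Writer, sharedAsm []byte, includePath string) {
	h := fnv.New64a()
	h.Write(sharedAsm)
	fmt.Fprintln(w, "// +build !purego")
	fmt.Fprintln(w, "// Code generated by gnark-crypto/generator. DO NOT EDIT.")
	fmt.Fprintf(w, "// We include the hash to force the Go compiler to recompile: %d\n", h.Sum64())
	fmt.Fprintf(w, "#include %q\n", includePath)
}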
-// We include the hash to force the Go compiler to recompile: 9776471158029765872 +// We include the hash to force the Go compiler to recompile: 11169582256709554223 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s index d74d9753b8..6683e0c8bd 100644 --- a/ecc/bn254/fr/element_ops_arm64.s +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9776471158029765872 +// We include the hash to force the Go compiler to recompile: 11169582256709554223 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index d74d9753b8..6683e0c8bd 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9776471158029765872 +// We include the hash to force the Go compiler to recompile: 11169582256709554223 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index d74d9753b8..6683e0c8bd 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9776471158029765872 +// We include the hash to force the Go compiler to recompile: 11169582256709554223 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index d4e2c9ccd2..5e04a59e91 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -5,127 +5,36 @@ // add(res, x, y *Element) TEXT ·add(SB), NOSPLIT, $0-24 - LDP x+8(FP), (R4, R5) + LDP x+8(FP), (R12, R13) // load operands and add mod 2^r - LDP 0(R4), (R0, R6) - LDP 0(R5), (R1, R7) - ADDS R0, R1, R0 - ADCS R6, R7, R1 - LDP 16(R4), (R2, R6) - LDP 16(R5), (R3, R7) - ADCS R2, R3, R2 - ADCS R6, R7, R3 + LDP 0(R12), (R8, R9) + LDP 0(R13), (R4, R5) + LDP 16(R12), (R10, R11) + LDP 16(R13), (R6, R7) + ADDS R8, R4, R4 + ADCS R9, R5, R5 + ADCS R10, R6, R6 + ADCS R11, R7, R7 // load modulus and subtract - LDP ·qElement+0(SB), (R8, R9) - SUBS R8, R0, R8 - SBCS R9, R1, R9 - LDP ·qElement+16(SB), (R10, R11) - SBCS R10, R2, R10 - SBCS R11, R3, R11 - - // reduce if necessary - CSEL CS, R8, R0, R0 - CSEL CS, R9, R1, R1 - CSEL CS, R10, R2, R2 - CSEL CS, R11, R3, R3 - - // store - MOVD res+0(FP), R12 - -#define STOREVECTOR(in0, in1, in2, in3, in4) \ - STP (in1, in2), 0(in0) \ - STP (in3, in4), 16(in0) \ - - STOREVECTOR(R12, R0, R1, R2, R3) - RET - -// sub(res, x, y *Element) -TEXT ·sub(SB), NOSPLIT, $0-24 - LDP x+8(FP), (R4, R5) - - // load operands and subtract mod 2^r - LDP 0(R4), (R0, R6) - LDP 0(R5), (R1, R7) - SUBS R1, R0, R0 - SBCS R7, R6, R1 - LDP 16(R4), (R2, R6) - LDP 16(R5), (R3, R7) - SBCS R3, R2, R2 - SBCS R7, R6, R3 - - // load modulus and select - MOVD $0, R12 - LDP ·qElement+0(SB), (R8, R9) - CSEL CS, R12, R8, R8 - CSEL CS, R12, R9, R9 - LDP ·qElement+16(SB), (R10, R11) - CSEL CS, R12, R10, R10 - CSEL CS, R12, R11, R11 - - // augment (or not) - ADDS R0, R8, R0 - ADCS R1, R9, R1 - ADCS R2, R10, R2 - ADCS R3, R11, R3 - - // store - MOVD res+0(FP), R13 - RET - 
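The rewritten add above is the whole modular addition in straight-line code: two LDP pair loads per operand, one ADDS followed by an ADCS carry chain, a trial subtraction of q with SUBS/SBCS, and CSEL CS to keep either the sum or the reduced sum. Note that SUBS restarts the flags, so the carry out of the last ADCS is simply dropped; that is sound because every modulus targeted by the 4-word path is below 2^255, so x+y always fits in four limbs. A Go sketch of the same branch-free algorithm (q0..q3 are placeholder limbs standing in for the per-package modulus constants):

package sketch

import "math/bits"

const q0, q1, q2, q3 = uint64(1), uint64(0), uint64(0), uint64(1) << 62 // placeholders

// addMod mirrors the generated assembly: full-width add, trial
// subtraction of q, then select on the borrow.
func addMod(z, x, y *[4]uint64) {
	var c uint64
	z[0], c = bits.Add64(x[0], y[0], 0)
	z[1], c = bits.Add64(x[1], y[1], c)
	z[2], c = bits.Add64(x[2], y[2], c)
	z[3], _ = bits.Add64(x[3], y[3], c) // carry cannot be set: q < 2^255

	var t [4]uint64
	var b uint64
	t[0], b = bits.Sub64(z[0], q0, 0)
	t[1], b = bits.Sub64(z[1], q1, b)
	t[2], b = bits.Sub64(z[2], q2, b)
	t[3], b = bits.Sub64(z[3], q3, b)

	if b == 0 { // no borrow ⇔ z ≥ q; this is the CSEL CS in the assembly
		*z = t
	}
}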
-// double(res, x *Element) -TEXT ·double(SB), NOSPLIT, $0-16 - LDP res+0(FP), (R5, R4) - - // load operands and add mod 2^r - LDP 0(R4), (R0, R1) - ADDS R0, R0, R0 - ADCS R1, R1, R1 - LDP 16(R4), (R2, R3) - ADCS R2, R2, R2 - ADCS R3, R3, R3 - - // load modulus and subtract - LDP ·qElement+0(SB), (R6, R7) - SUBS R6, R0, R6 - SBCS R7, R1, R7 - LDP ·qElement+16(SB), (R8, R9) - SBCS R8, R2, R8 - SBCS R9, R3, R9 - - // reduce if necessary - CSEL CS, R6, R0, R0 - CSEL CS, R7, R1, R1 - CSEL CS, R8, R2, R2 - CSEL CS, R9, R3, R3 - - // store - RET - -// neg(res, x *Element) -TEXT ·neg(SB), NOSPLIT, $0-16 - LDP res+0(FP), (R5, R4) - - // load operands and subtract - MOVD $0, R8 - LDP 0(R4), (R0, R1) - LDP ·qElement+0(SB), (R6, R7) - ORR R0, R8, R8 // has x been 0 so far? - ORR R1, R8, R8 - SUBS R0, R6, R0 - SBCS R1, R7, R1 - LDP 16(R4), (R2, R3) - LDP ·qElement+16(SB), (R6, R7) - ORR R2, R8, R8 // has x been 0 so far? - ORR R3, R8, R8 + MOVD $const_q0, R0 + MOVD $const_q1, R1 + MOVD $const_q2, R2 + MOVD $const_q3, R3 + SUBS R0, R4, R0 + SBCS R1, R5, R1 SBCS R2, R6, R2 SBCS R3, R7, R3 - TST $0xffffffffffffffff, R8 - CSEL EQ, R8, R0, R0 - CSEL EQ, R8, R1, R1 - CSEL EQ, R8, R2, R2 - CSEL EQ, R8, R3, R3 + + // reduce if necessary + CSEL CS, R0, R4, R4 + CSEL CS, R1, R5, R5 + CSEL CS, R2, R6, R6 + CSEL CS, R3, R7, R7 // store + MOVD res+0(FP), R14 + STP (R4, R5), 0(R14) + STP (R6, R7), 16(R14) RET diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index d88428bd2a..0839fe9ef8 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -108,9 +108,9 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { } f.generateAdd() - f.generateSub() - f.generateDouble() - f.generateNeg() + // f.generateSub() + // f.generateDouble() + // f.generateNeg() return nil } @@ -289,3 +289,7 @@ func (f *FFArm64) qAt(index int) string { func (f *FFArm64) qInv0() string { return "$const_qInvNeg" } + +func (f *FFArm64) qi(i int) string { + return fmt.Sprintf("$const_q%d", i) +} diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index 2b55450a3b..3ca30f90c3 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -25,44 +25,35 @@ func (f *FFArm64) generateAdd() { defer f.AssertCleanStack(0, 0) // registers + t := registers.PopN(f.NbWords) z := registers.PopN(f.NbWords) + x := registers.PopN(f.NbWords) xPtr := registers.Pop() yPtr := registers.Pop() - ops := registers.PopN(2) + zPtr := registers.Pop() f.LDP("x+8(FP)", xPtr, yPtr) f.Comment("load operands and add mod 2^r") - op0 := f.ADDS for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(xPtr.At(i), z[i], ops[0]) - f.LDP(yPtr.At(i), z[i+1], ops[1]) - - op0(z[i], z[i+1], z[i]) - op0 = f.ADCS - - f.ADCS(ops[0], ops[1], z[i+1]) + f.LDP(xPtr.At(i), x[i], x[i+1]) + f.LDP(yPtr.At(i), z[i], z[i+1]) } - registers.Push(xPtr, yPtr) - registers.Push(ops...) + f.ADDS(x[0], z[0], z[0]) + for i := 1; i < f.NbWords; i++ { + f.ADCS(x[i], z[i], z[i]) + } - t := registers.PopN(f.NbWords) f.reduce(z, t) - registers.Push(t...) f.Comment("store") - zPtr := registers.Pop() + f.MOVD("res+0(FP)", zPtr) - storeVector := f.Define("storeVector", f.NbWords+1, func(args ...arm64.Register) { - res0 := args[0] - for i := 1; i < len(args); i += 2 { - f.STP(args[i], args[i+1], res0.At(i-1)) - } - }) - _z := append([]arm64.Register{zPtr}, z...) - storeVector(_z...) 
+ for i := 0; i < f.NbWords-1; i += 2 { + f.STP(z[i], z[i+1], zPtr.At(i)) + } f.RET() @@ -234,21 +225,12 @@ func (f *FFArm64) reduce(z, t []arm64.Register) { f.Comment("load modulus and subtract") - op0 := f.SUBS - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.qAt(i), t[i], t[i+1]) - - op0(t[i], z[i], t[i]) - op0 = f.SBCS - - f.SBCS(t[i+1], z[i+1], t[i+1]) + for i := 0; i < f.NbWords; i++ { + f.MOVD(f.qi(i), t[i]) } - - if f.NbWords%2 == 1 { - i := f.NbWords - 1 - f.MOVD(f.qAt(i), t[i]) - - op0(t[i], z[i], t[i]) + f.SUBS(t[0], z[0], t[0]) + for i := 1; i < f.NbWords; i++ { + f.SBCS(t[i], z[i], t[i]) } f.Comment("reduce if necessary") From c9363b5ae70e5eb16d81f4af80aa651b7f545456 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 16 Oct 2024 15:45:28 +0000 Subject: [PATCH 05/74] checkpoint --- ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element_ops_arm64.s | 2 +- field/asm/element_4w_arm64.s | 6 ++---- field/generator/asm/arm64/element_ops.go | 5 +++-- 10 files changed, 13 insertions(+), 14 deletions(-) diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index 6683e0c8bd..0e83c291a5 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 11169582256709554223 +// We include the hash to force the Go compiler to recompile: 10009927169843392352 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s index 6683e0c8bd..0e83c291a5 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 11169582256709554223 +// We include the hash to force the Go compiler to recompile: 10009927169843392352 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s index 6683e0c8bd..0e83c291a5 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 11169582256709554223 +// We include the hash to force the Go compiler to recompile: 10009927169843392352 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s index 6683e0c8bd..0e83c291a5 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 11169582256709554223 +// We include the hash to force the Go compiler to recompile: 10009927169843392352 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s index 6683e0c8bd..0e83c291a5 100644 --- a/ecc/bn254/fp/element_ops_arm64.s +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 11169582256709554223 +// We include the hash to force the Go compiler to recompile: 10009927169843392352 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s index 6683e0c8bd..0e83c291a5 100644 --- a/ecc/bn254/fr/element_ops_arm64.s +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 11169582256709554223 +// We include the hash to force the Go compiler to recompile: 10009927169843392352 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index 6683e0c8bd..0e83c291a5 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 11169582256709554223 +// We include the hash to force the Go compiler to recompile: 10009927169843392352 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index 6683e0c8bd..0e83c291a5 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
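Two ways of materializing q have now appeared in the generated add: MOVD $const_qN immediates (the previous checkpoint), where the const_qN symbols are Go constants exported to the assembler through the included go_asm.h, and LDP ·qElement+off(SB) pair loads from a package-level symbol — the earlier form, which this checkpoint switches back to, as shown below. A 64-bit immediate generally expands to a MOVZ/MOVK sequence, so four immediates cost more instructions than two LDPs from a line that is hot in cache. The two accessors side by side — qi is taken verbatim from the build.go diff above, while qAt's body is an assumption reconstructed from the emitted ·qElement+0(SB) / ·qElement+16(SB):

// qi names a Go constant that go_asm.h exposes to the assembler.
func (f *FFArm64) qi(i int) string { return fmt.Sprintf("$const_q%d", i) }

// qAt addresses limb i of the package-level modulus; the 8-byte
// scaling is inferred from the generated offsets, not shown here.
func (f *FFArm64) qAt(i int) string { return fmt.Sprintf("·qElement+%d(SB)", 8*i) }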
-// We include the hash to force the Go compiler to recompile: 11169582256709554223 +// We include the hash to force the Go compiler to recompile: 10009927169843392352 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index 5e04a59e91..cae279e351 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -18,10 +18,8 @@ TEXT ·add(SB), NOSPLIT, $0-24 ADCS R11, R7, R7 // load modulus and subtract - MOVD $const_q0, R0 - MOVD $const_q1, R1 - MOVD $const_q2, R2 - MOVD $const_q3, R3 + LDP ·qElement+0(SB), (R0, R1) + LDP ·qElement+16(SB), (R2, R3) SUBS R0, R4, R0 SBCS R1, R5, R1 SBCS R2, R6, R2 diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index 3ca30f90c3..153e7cb43b 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -225,8 +225,9 @@ func (f *FFArm64) reduce(z, t []arm64.Register) { f.Comment("load modulus and subtract") - for i := 0; i < f.NbWords; i++ { - f.MOVD(f.qi(i), t[i]) + for i := 0; i < f.NbWords-1; i += 2 { + // f.MOVD(f.qi(i), t[i]) + f.LDP(f.qAt(i), t[i], t[i+1]) } f.SUBS(t[0], z[0], t[0]) for i := 1; i < f.NbWords; i++ { From 0015ca0f941a2de870196a280919ade67b0c002d Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 16 Oct 2024 16:08:42 +0000 Subject: [PATCH 06/74] checkpoint --- ecc/bls12-377/fp/element.go | 48 ------------------- ecc/bls12-377/fr/element.go | 20 -------- ecc/bls12-377/fr/element_ops_arm64.go | 8 ++++ ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fp/element.go | 48 ------------------- ecc/bls12-381/fr/element.go | 20 -------- ecc/bls12-381/fr/element_ops_arm64.go | 8 ++++ ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element.go | 20 -------- ecc/bls24-315/fr/element_ops_arm64.go | 8 ++++ ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element.go | 20 -------- ecc/bls24-317/fr/element_ops_arm64.go | 8 ++++ ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bn254/fp/element.go | 20 -------- ecc/bn254/fp/element_ops_arm64.go | 8 ++++ ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fr/element.go | 20 -------- ecc/bn254/fr/element_ops_arm64.go | 8 ++++ ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/bw6-761/fr/element.go | 48 ------------------- ecc/stark-curve/fp/element.go | 20 -------- ecc/stark-curve/fp/element_ops_arm64.go | 8 ++++ ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element.go | 20 -------- ecc/stark-curve/fr/element_ops_arm64.go | 8 ++++ ecc/stark-curve/fr/element_ops_arm64.s | 2 +- field/asm/element_4w_arm64.s | 35 ++++++++++++-- field/generator/asm/arm64/build.go | 2 +- field/generator/asm/arm64/element_ops.go | 27 ++--------- field/generator/config/field_config.go | 2 +- .../internal/templates/element/base.go | 4 +- .../internal/templates/element/ops_asm.go | 8 ++++ 33 files changed, 122 insertions(+), 340 deletions(-) diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 393f45744d..9cbbb80f5f 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -418,54 +418,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], carry = bits.Add64(x[3], y[3], carry) - z[4], carry = bits.Add64(x[4], y[4], carry) - z[5], _ = 
bits.Add64(x[5], y[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], carry = bits.Add64(x[3], x[3], carry) - z[4], carry = bits.Add64(x[4], x[4], carry) - z[5], _ = bits.Add64(x[5], x[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index 5e34cbd2ba..1eb3c4ad7e 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/bls12-377/fr/element_ops_arm64.go b/ecc/bls12-377/fr/element_ops_arm64.go index 1591b31d62..604eec31d1 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.go +++ b/ecc/bls12-377/fr/element_ops_arm64.go @@ -26,3 +26,11 @@ func (z *Element) Add(x, y *Element) *Element { add(z, x, y) return z } + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index 0e83c291a5..f6b3843577 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
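With the portable bodies deleted from element.go, the bodiless declarations in element_ops_arm64.go are the only definition of add and double on arm64; they bind by name to TEXT ·add(SB) and TEXT ·double(SB) in the included assembly. The //go:noescape annotation is load-bearing: it promises the assembly does not retain the pointers it receives, so the compiler may keep the pointed-to values on the stack. A hypothetical caller, written only to show that effect:

// sum is illustrative, not part of the patch: because add is marked
// //go:noescape, taking the address of the local acc does not force
// it onto the heap, and the loop can run without allocations.
func sum(vs []Element) Element {
	var acc Element
	for i := range vs {
		add(&acc, &acc, &vs[i])
	}
	return acc
}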
-// We include the hash to force the Go compiler to recompile: 10009927169843392352 +// We include the hash to force the Go compiler to recompile: 7553173984042352417 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index f0bcfe51bc..4fea510cd1 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -418,54 +418,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], carry = bits.Add64(x[3], y[3], carry) - z[4], carry = bits.Add64(x[4], y[4], carry) - z[5], _ = bits.Add64(x[5], y[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], carry = bits.Add64(x[3], x[3], carry) - z[4], carry = bits.Add64(x[4], x[4], carry) - z[5], _ = bits.Add64(x[5], x[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index 8cf8b8bd78..391d8052c3 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/bls12-381/fr/element_ops_arm64.go b/ecc/bls12-381/fr/element_ops_arm64.go index 1591b31d62..604eec31d1 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.go +++ b/ecc/bls12-381/fr/element_ops_arm64.go @@ -26,3 +26,11 @@ func (z *Element) Add(x, y *Element) *Element { add(z, x, y) return z } + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s index 0e83c291a5..f6b3843577 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 10009927169843392352 +// We include the hash to force the Go compiler to recompile: 7553173984042352417 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index cae603ade5..1bce0057e5 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/bls24-315/fr/element_ops_arm64.go b/ecc/bls24-315/fr/element_ops_arm64.go index 1591b31d62..604eec31d1 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.go +++ b/ecc/bls24-315/fr/element_ops_arm64.go @@ -26,3 +26,11 @@ func (z *Element) Add(x, y *Element) *Element { add(z, x, y) return z } + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s index 0e83c291a5..f6b3843577 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 10009927169843392352 +// We include the hash to force the Go compiler to recompile: 7553173984042352417 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element.go b/ecc/bls24-317/fr/element.go index 790491630c..59fa5911bd 100644 --- a/ecc/bls24-317/fr/element.go +++ b/ecc/bls24-317/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/bls24-317/fr/element_ops_arm64.go b/ecc/bls24-317/fr/element_ops_arm64.go index 1591b31d62..604eec31d1 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.go +++ b/ecc/bls24-317/fr/element_ops_arm64.go @@ -26,3 +26,11 @@ func (z *Element) Add(x, y *Element) *Element { add(z, x, y) return z } + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s index 0e83c291a5..f6b3843577 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 10009927169843392352 +// We include the hash to force the Go compiler to recompile: 7553173984042352417 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index 87323af2bc..2b97a0950a 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/bn254/fp/element_ops_arm64.go b/ecc/bn254/fp/element_ops_arm64.go index c7a28f43c7..c4a02f2a43 100644 --- a/ecc/bn254/fp/element_ops_arm64.go +++ b/ecc/bn254/fp/element_ops_arm64.go @@ -26,3 +26,11 @@ func (z *Element) Add(x, y *Element) *Element { add(z, x, y) return z } + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s index 0e83c291a5..f6b3843577 100644 --- a/ecc/bn254/fp/element_ops_arm64.s +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. 
DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 10009927169843392352 +// We include the hash to force the Go compiler to recompile: 7553173984042352417 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index 3da98a1b57..07b7b6a2b5 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/bn254/fr/element_ops_arm64.go b/ecc/bn254/fr/element_ops_arm64.go index 1591b31d62..604eec31d1 100644 --- a/ecc/bn254/fr/element_ops_arm64.go +++ b/ecc/bn254/fr/element_ops_arm64.go @@ -26,3 +26,11 @@ func (z *Element) Add(x, y *Element) *Element { add(z, x, y) return z } + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s index 0e83c291a5..f6b3843577 100644 --- a/ecc/bn254/fr/element_ops_arm64.s +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 10009927169843392352 +// We include the hash to force the Go compiler to recompile: 7553173984042352417 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index 6784bc911f..a90d04bb53 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -418,54 +418,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], carry = bits.Add64(x[3], y[3], carry) - z[4], carry = bits.Add64(x[4], y[4], carry) - z[5], _ = bits.Add64(x[5], y[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], carry = bits.Add64(x[3], x[3], carry) - z[4], carry = bits.Add64(x[4], x[4], carry) - z[5], _ = bits.Add64(x[5], x[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], 
q5, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/stark-curve/fp/element.go b/ecc/stark-curve/fp/element.go index 7a68362d40..fc22d531e1 100644 --- a/ecc/stark-curve/fp/element.go +++ b/ecc/stark-curve/fp/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/stark-curve/fp/element_ops_arm64.go b/ecc/stark-curve/fp/element_ops_arm64.go index c7a28f43c7..c4a02f2a43 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.go +++ b/ecc/stark-curve/fp/element_ops_arm64.go @@ -26,3 +26,11 @@ func (z *Element) Add(x, y *Element) *Element { add(z, x, y) return z } + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index 0e83c291a5..f6b3843577 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 10009927169843392352 +// We include the hash to force the Go compiler to recompile: 7553173984042352417 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element.go b/ecc/stark-curve/fr/element.go index 601d7ba831..f393726632 100644 --- a/ecc/stark-curve/fr/element.go +++ b/ecc/stark-curve/fr/element.go @@ -393,26 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { var b uint64 diff --git a/ecc/stark-curve/fr/element_ops_arm64.go b/ecc/stark-curve/fr/element_ops_arm64.go index 1591b31d62..604eec31d1 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.go +++ b/ecc/stark-curve/fr/element_ops_arm64.go @@ -26,3 +26,11 @@ func (z *Element) Add(x, y *Element) *Element { add(z, x, y) return z } + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index 0e83c291a5..f6b3843577 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. 
DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 10009927169843392352 +// We include the hash to force the Go compiler to recompile: 7553173984042352417 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index cae279e351..7d5ebd3a80 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -5,9 +5,7 @@ // add(res, x, y *Element) TEXT ·add(SB), NOSPLIT, $0-24 - LDP x+8(FP), (R12, R13) - - // load operands and add mod 2^r + LDP x+8(FP), (R12, R13) LDP 0(R12), (R8, R9) LDP 0(R13), (R4, R5) LDP 16(R12), (R10, R11) @@ -36,3 +34,34 @@ TEXT ·add(SB), NOSPLIT, $0-24 STP (R4, R5), 0(R14) STP (R6, R7), 16(R14) RET + +// double(res, x *Element) +TEXT ·double(SB), NOSPLIT, $0-16 + LDP res+0(FP), (R5, R4) + + // load operands and add mod 2^r + LDP 0(R4), (R0, R1) + ADDS R0, R0, R0 + ADCS R1, R1, R1 + LDP 16(R4), (R2, R3) + ADCS R2, R2, R2 + ADCS R3, R3, R3 + + // load modulus and subtract + LDP ·qElement+0(SB), (R6, R7) + LDP ·qElement+16(SB), (R8, R9) + SUBS R6, R0, R6 + SBCS R7, R1, R7 + SBCS R8, R2, R8 + SBCS R9, R3, R9 + + // reduce if necessary + CSEL CS, R6, R0, R0 + CSEL CS, R7, R1, R1 + CSEL CS, R8, R2, R2 + CSEL CS, R9, R3, R3 + + // store + STP (R0, R1), 0(R5) + STP (R2, R3), 16(R5) + RET diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 0839fe9ef8..d9cf214c2b 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -108,8 +108,8 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { } f.generateAdd() + f.generateDouble() // f.generateSub() - // f.generateDouble() // f.generateNeg() return nil diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index 153e7cb43b..ea52f948ce 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -20,7 +20,6 @@ import ( func (f *FFArm64) generateAdd() { f.Comment("add(res, x, y *Element)") - //stackSize := f.StackSize(f.NbWords*2, 0, 0) registers := f.FnHeader("add", 0, 24) defer f.AssertCleanStack(0, 0) @@ -33,7 +32,6 @@ func (f *FFArm64) generateAdd() { zPtr := registers.Pop() f.LDP("x+8(FP)", xPtr, yPtr) - f.Comment("load operands and add mod 2^r") for i := 0; i < f.NbWords-1; i += 2 { f.LDP(xPtr.At(i), x[i], x[i+1]) @@ -50,10 +48,7 @@ func (f *FFArm64) generateAdd() { f.Comment("store") f.MOVD("res+0(FP)", zPtr) - - for i := 0; i < f.NbWords-1; i += 2 { - f.STP(z[i], z[i+1], zPtr.At(i)) - } + f.storeVector(z, zPtr) f.RET() @@ -83,11 +78,6 @@ func (f *FFArm64) generateDouble() { f.ADCS(z[i+1], z[i+1], z[i+1]) } - if f.NbWords%2 == 1 { - i := f.NbWords - 1 - f.MOVD(xPtr.At(i), z[i]) - op0(z[i], z[i], z[i]) - } registers.Push(xPtr) t := registers.PopN(f.NbWords) @@ -144,13 +134,6 @@ func (f *FFArm64) generateSub() { f.CSEL("CS", zero, t[i+1], t[i+1]) } - if f.NbWords%2 == 1 { - i := f.NbWords - 1 - f.MOVD(f.qAt(i), t[i]) - - f.CSEL("CS", zero, t[i], t[i]) - } - registers.Push(zero) f.Comment("augment (or not)") @@ -226,7 +209,6 @@ func (f *FFArm64) reduce(z, t []arm64.Register) { f.Comment("load modulus and subtract") for i := 0; i < f.NbWords-1; i += 2 { - // f.MOVD(f.qi(i), t[i]) f.LDP(f.qAt(i), t[i], t[i+1]) } f.SUBS(t[0], z[0], t[0]) @@ -235,12 +217,13 @@ func (f *FFArm64) reduce(z, t []arm64.Register) { } f.Comment("reduce if necessary") - for i := 0; i < f.NbWords; i++ { f.CSEL("CS", t[i], z[i], z[i]) } } -func (f *FFArm64) storeVector(vector 
interface{}, baseAddress arm64.Register) { - // f.callTemplate("storeVector", toInterfaceSlice(baseAddress, vector)...) +func (f *FFArm64) storeVector(z []arm64.Register, zPtr arm64.Register) { + for i := 0; i < f.NbWords-1; i += 2 { + f.STP(z[i], z[i+1], zPtr.At(i)) + } } diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 9b0d8d47d0..f9617f6e1d 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -264,7 +264,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // asm code generation for moduli with more than 6 words can be optimized further F.ASM = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 F.ASMVector = F.ASM && F.NbWords == 4 && F.NbBits > 225 - F.ASMArm = F.ASMVector + F.ASMArm = F.ASMVector || (F.NbWords == 6) // setting Mu 2^288 / q if F.NbWords == 4 { diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index f54a5181d6..8cd623514a 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -431,7 +431,7 @@ func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { {{- end}} return z } -{{- end}} + // Double z = x + x (mod q), aka Lsh 1 func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} { @@ -474,6 +474,8 @@ func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} { return z } +{{- end}} + // Sub z = x - y (mod q) func (z *{{.ElementName}}) Sub( x, y *{{.ElementName}}) *{{.ElementName}} { diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index e07c71e628..896b19b857 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -206,5 +206,13 @@ func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { add(z,x,y) return z } + +//go:noescape +func double(res, x *{{.ElementName}}) + +func (z *{{.ElementName}}) Double(x *{{.ElementName}}) *{{.ElementName}} { + double(z,x) + return z +} {{end}} ` From e38579fd366921f702b5b6eb93a6923b46babe29 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 16 Oct 2024 16:11:52 +0000 Subject: [PATCH 07/74] checkpoint --- ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element_ops_arm64.s | 2 +- field/asm/element_4w_arm64.s | 36 +++++++++++------------- field/generator/asm/arm64/element_ops.go | 19 ++++--------- 10 files changed, 30 insertions(+), 41 deletions(-) diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index f6b3843577..fad68ec28f 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
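The new storeVector helper pins down the addressing convention: At(i) takes a limb index (8 bytes per limb), and each STP writes a pair of limbs, so a 4-word element is stored by exactly two instructions at byte offsets 0 and 16. That convention is what the first checkpoint's STOREVECTOR fix was about — the second pair had been written at byte 128 because a byte offset (8*(i-1)) was passed where At expected a limb index. Annotated form of the helper as it now stands:

// one STP per limb pair; At(i) resolves limb i to byte offset 8*i
// (inferred from the emitted 0(R14)/16(R14) offsets).
func (f *FFArm64) storeVector(z []arm64.Register, zPtr arm64.Register) {
	for i := 0; i < f.NbWords-1; i += 2 {
		f.STP(z[i], z[i+1], zPtr.At(i)) // limbs i, i+1 → bytes 8i .. 8i+15
	}
}

In the same checkpoint, the {{- end}} move in base.go pulls the portable Double inside the same conditional that already excluded Add, so packages that get assembly stubs render the double stub from ops_asm.go instead of the pure-Go body.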
-// We include the hash to force the Go compiler to recompile: 7553173984042352417 +// We include the hash to force the Go compiler to recompile: 1365634154298741738 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s index f6b3843577..fad68ec28f 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 7553173984042352417 +// We include the hash to force the Go compiler to recompile: 1365634154298741738 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s index f6b3843577..fad68ec28f 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 7553173984042352417 +// We include the hash to force the Go compiler to recompile: 1365634154298741738 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s index f6b3843577..fad68ec28f 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 7553173984042352417 +// We include the hash to force the Go compiler to recompile: 1365634154298741738 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s index f6b3843577..fad68ec28f 100644 --- a/ecc/bn254/fp/element_ops_arm64.s +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 7553173984042352417 +// We include the hash to force the Go compiler to recompile: 1365634154298741738 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s index f6b3843577..fad68ec28f 100644 --- a/ecc/bn254/fr/element_ops_arm64.s +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 7553173984042352417 +// We include the hash to force the Go compiler to recompile: 1365634154298741738 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index f6b3843577..fad68ec28f 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 7553173984042352417 +// We include the hash to force the Go compiler to recompile: 1365634154298741738 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index f6b3843577..fad68ec28f 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 7553173984042352417 +// We include the hash to force the Go compiler to recompile: 1365634154298741738 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index 7d5ebd3a80..c00eeb970d 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -37,31 +37,27 @@ TEXT ·add(SB), NOSPLIT, $0-24 // double(res, x *Element) TEXT ·double(SB), NOSPLIT, $0-16 - LDP res+0(FP), (R5, R4) - - // load operands and add mod 2^r - LDP 0(R4), (R0, R1) - ADDS R0, R0, R0 - ADCS R1, R1, R1 - LDP 16(R4), (R2, R3) - ADCS R2, R2, R2 + LDP res+0(FP), (R1, R0) + LDP 0(R0), (R2, R3) + LDP 16(R0), (R4, R5) + ADDS R2, R2, R2 ADCS R3, R3, R3 + ADCS R4, R4, R4 + ADCS R5, R5, R5 // load modulus and subtract LDP ·qElement+0(SB), (R6, R7) LDP ·qElement+16(SB), (R8, R9) - SUBS R6, R0, R6 - SBCS R7, R1, R7 - SBCS R8, R2, R8 - SBCS R9, R3, R9 + SUBS R6, R2, R6 + SBCS R7, R3, R7 + SBCS R8, R4, R8 + SBCS R9, R5, R9 // reduce if necessary - CSEL CS, R6, R0, R0 - CSEL CS, R7, R1, R1 - CSEL CS, R8, R2, R2 - CSEL CS, R9, R3, R3 - - // store - STP (R0, R1), 0(R5) - STP (R2, R3), 16(R5) + CSEL CS, R6, R2, R2 + CSEL CS, R7, R3, R3 + CSEL CS, R8, R4, R4 + CSEL CS, R9, R5, R5 + STP (R2, R3), 0(R1) + STP (R4, R5), 16(R1) RET diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index ea52f948ce..0a5923f7fb 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -60,31 +60,24 @@ func (f *FFArm64) generateDouble() { defer f.AssertCleanStack(0, 0) // registers - z := registers.PopN(f.NbWords) xPtr := registers.Pop() zPtr := registers.Pop() - //ops := registers.PopN(2) + z := registers.PopN(f.NbWords) + t := registers.PopN(f.NbWords) f.LDP("res+0(FP)", zPtr, xPtr) - f.Comment("load operands and add mod 2^r") - op0 := f.ADDS for i := 0; i < f.NbWords-1; i += 2 { f.LDP(xPtr.At(i), z[i], z[i+1]) - - op0(z[i], z[i], z[i]) - op0 = f.ADCS - - f.ADCS(z[i+1], z[i+1], z[i+1]) } - registers.Push(xPtr) + f.ADDS(z[0], z[0], z[0]) + for i := 1; i < f.NbWords; i++ { + f.ADCS(z[i], z[i], z[i]) + } - t := registers.PopN(f.NbWords) f.reduce(z, t) - registers.Push(t...) 
- f.Comment("store") f.storeVector(z, zPtr) f.RET() From a3343879aea257cb7ec9fc70a9c269b29787fd4c Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 16 Oct 2024 16:25:17 +0000 Subject: [PATCH 08/74] checkpoint --- ecc/bls12-377/fp/element.go | 21 ------ ecc/bls12-377/fr/element.go | 17 ----- ecc/bls12-377/fr/element_ops_arm64.go | 8 +++ ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fp/element.go | 21 ------ ecc/bls12-381/fr/element.go | 17 ----- ecc/bls12-381/fr/element_ops_arm64.go | 8 +++ ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element.go | 17 ----- ecc/bls24-315/fr/element_ops_arm64.go | 8 +++ ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element.go | 17 ----- ecc/bls24-317/fr/element_ops_arm64.go | 8 +++ ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bn254/fp/element.go | 17 ----- ecc/bn254/fp/element_ops_arm64.go | 8 +++ ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fr/element.go | 17 ----- ecc/bn254/fr/element_ops_arm64.go | 8 +++ ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/bw6-761/fr/element.go | 21 ------ ecc/stark-curve/fp/element.go | 17 ----- ecc/stark-curve/fp/element_ops_arm64.go | 8 +++ ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element.go | 17 ----- ecc/stark-curve/fr/element_ops_arm64.go | 8 +++ ecc/stark-curve/fr/element_ops_arm64.s | 2 +- field/asm/element_4w_arm64.s | 32 ++++++++- field/generator/asm/arm64/build.go | 2 +- field/generator/asm/arm64/element_ops.go | 72 ++++++++----------- .../internal/templates/element/base.go | 4 +- .../internal/templates/element/ops_asm.go | 9 +++ 32 files changed, 145 insertions(+), 253 deletions(-) diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 9cbbb80f5f..408b94653d 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -418,27 +418,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - z[4], b = bits.Sub64(x[4], y[4], b) - z[5], b = bits.Sub64(x[5], y[5], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], c = bits.Add64(z[3], q3, c) - z[4], c = bits.Add64(z[4], q4, c) - z[5], _ = bits.Add64(z[5], q5, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index 1eb3c4ad7e..e37805eaa6 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -393,23 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], _ = bits.Add64(z[3], q3, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-377/fr/element_ops_arm64.go b/ecc/bls12-377/fr/element_ops_arm64.go index 604eec31d1..36cb58106c 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.go +++ b/ecc/bls12-377/fr/element_ops_arm64.go @@ -34,3 
+34,11 @@ func (z *Element) Double(x *Element) *Element { double(z, x) return z } + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index fad68ec28f..34606e5f07 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 1365634154298741738 +// We include the hash to force the Go compiler to recompile: 16055964597816771835 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index 4fea510cd1..1609ca9524 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -418,27 +418,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - z[4], b = bits.Sub64(x[4], y[4], b) - z[5], b = bits.Sub64(x[5], y[5], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], c = bits.Add64(z[3], q3, c) - z[4], c = bits.Add64(z[4], q4, c) - z[5], _ = bits.Add64(z[5], q5, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index 391d8052c3..bcc9c6e251 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -393,23 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], _ = bits.Add64(z[3], q3, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-381/fr/element_ops_arm64.go b/ecc/bls12-381/fr/element_ops_arm64.go index 604eec31d1..36cb58106c 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.go +++ b/ecc/bls12-381/fr/element_ops_arm64.go @@ -34,3 +34,11 @@ func (z *Element) Double(x *Element) *Element { double(z, x) return z } + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s index fad68ec28f..34606e5f07 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
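Subtraction inverts the correction: subtract with borrow, and if the top limb borrowed, add q back once. The arm64 sub removed earlier in this series did the correction branchlessly — CSEL picks 0 or q limb-wise, then an unconditional add chain — and the new sub stubs above presumably bind to a regenerated version of the same routine. A Go sketch of the branch-free form (q0..q3 are placeholder limbs):

package sketch

import "math/bits"

const q0, q1, q2, q3 = uint64(1), uint64(0), uint64(0), uint64(1) << 62 // placeholders

// subMod: subtract with borrow, then add back q masked by the borrow.
// mask plays the role of the CSEL operands in the assembly.
func subMod(z, x, y *[4]uint64) {
	var b uint64
	z[0], b = bits.Sub64(x[0], y[0], 0)
	z[1], b = bits.Sub64(x[1], y[1], b)
	z[2], b = bits.Sub64(x[2], y[2], b)
	z[3], b = bits.Sub64(x[3], y[3], b)

	mask := -b // 0 if x ≥ y, all ones if the subtraction wrapped
	var c uint64
	z[0], c = bits.Add64(z[0], q0&mask, 0)
	z[1], c = bits.Add64(z[1], q1&mask, c)
	z[2], c = bits.Add64(z[2], q2&mask, c)
	z[3], _ = bits.Add64(z[3], q3&mask, c)
}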
-// We include the hash to force the Go compiler to recompile: 1365634154298741738 +// We include the hash to force the Go compiler to recompile: 16055964597816771835 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index 1bce0057e5..f565c96360 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -393,23 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], _ = bits.Add64(z[3], q3, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls24-315/fr/element_ops_arm64.go b/ecc/bls24-315/fr/element_ops_arm64.go index 604eec31d1..36cb58106c 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.go +++ b/ecc/bls24-315/fr/element_ops_arm64.go @@ -34,3 +34,11 @@ func (z *Element) Double(x *Element) *Element { double(z, x) return z } + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s index fad68ec28f..34606e5f07 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 1365634154298741738 +// We include the hash to force the Go compiler to recompile: 16055964597816771835 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element.go b/ecc/bls24-317/fr/element.go index 59fa5911bd..f00c1ed57a 100644 --- a/ecc/bls24-317/fr/element.go +++ b/ecc/bls24-317/fr/element.go @@ -393,23 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], _ = bits.Add64(z[3], q3, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls24-317/fr/element_ops_arm64.go b/ecc/bls24-317/fr/element_ops_arm64.go index 604eec31d1..36cb58106c 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.go +++ b/ecc/bls24-317/fr/element_ops_arm64.go @@ -34,3 +34,11 @@ func (z *Element) Double(x *Element) *Element { double(z, x) return z } + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s index fad68ec28f..34606e5f07 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
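// The Go side of this pattern is identical for every curve (see the
// element_ops_arm64.go hunks above): the assembly routine is declared with
// //go:noescape and wrapped in a method, e.g.
//
//	//go:noescape
//	func sub(res, x, y *Element)
//
//	func (z *Element) Sub(x, y *Element) *Element {
//		sub(z, x, y)
//		return z
//	}
//
// //go:noescape tells escape analysis that the pointer arguments do not
// escape through the assembly, so callers can keep operands on the stack;
// the tiny exported wrapper is trivially inlinable while the asm symbol
// stays unexported.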
-// We include the hash to force the Go compiler to recompile: 1365634154298741738 +// We include the hash to force the Go compiler to recompile: 16055964597816771835 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index 2b97a0950a..e58b0ac6b2 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -393,23 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], _ = bits.Add64(z[3], q3, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bn254/fp/element_ops_arm64.go b/ecc/bn254/fp/element_ops_arm64.go index c4a02f2a43..d09778d0d0 100644 --- a/ecc/bn254/fp/element_ops_arm64.go +++ b/ecc/bn254/fp/element_ops_arm64.go @@ -34,3 +34,11 @@ func (z *Element) Double(x *Element) *Element { double(z, x) return z } + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s index fad68ec28f..34606e5f07 100644 --- a/ecc/bn254/fp/element_ops_arm64.s +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 1365634154298741738 +// We include the hash to force the Go compiler to recompile: 16055964597816771835 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index 07b7b6a2b5..186b9e6a39 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -393,23 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], _ = bits.Add64(z[3], q3, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bn254/fr/element_ops_arm64.go b/ecc/bn254/fr/element_ops_arm64.go index 604eec31d1..36cb58106c 100644 --- a/ecc/bn254/fr/element_ops_arm64.go +++ b/ecc/bn254/fr/element_ops_arm64.go @@ -34,3 +34,11 @@ func (z *Element) Double(x *Element) *Element { double(z, x) return z } + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s index fad68ec28f..34606e5f07 100644 --- a/ecc/bn254/fr/element_ops_arm64.s +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
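// A note on the flag convention the generated arm64 code relies on: unlike
// x86, AArch64 sets the carry flag on subtraction when there is *no* borrow,
// so "CS" (carry set) after SUBS/SBCS means the difference did not
// underflow. That is why the sub routine added to element_4w_arm64.s below
// can reduce without a branch:
//
//	SUBS/SBCS             compute x - y, leaving C = 1 iff no underflow
//	CSEL CS, ZR, Rq, Rq   zero each modulus limb when C is set
//	ADDS/ADCS             add back q on underflow, 0 otherwise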
-// We include the hash to force the Go compiler to recompile: 1365634154298741738 +// We include the hash to force the Go compiler to recompile: 16055964597816771835 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index a90d04bb53..a887b71537 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -418,27 +418,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - z[4], b = bits.Sub64(x[4], y[4], b) - z[5], b = bits.Sub64(x[5], y[5], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], c = bits.Add64(z[3], q3, c) - z[4], c = bits.Add64(z[4], q4, c) - z[5], _ = bits.Add64(z[5], q5, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/stark-curve/fp/element.go b/ecc/stark-curve/fp/element.go index fc22d531e1..dc0de49c67 100644 --- a/ecc/stark-curve/fp/element.go +++ b/ecc/stark-curve/fp/element.go @@ -393,23 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], _ = bits.Add64(z[3], q3, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/stark-curve/fp/element_ops_arm64.go b/ecc/stark-curve/fp/element_ops_arm64.go index c4a02f2a43..d09778d0d0 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.go +++ b/ecc/stark-curve/fp/element_ops_arm64.go @@ -34,3 +34,11 @@ func (z *Element) Double(x *Element) *Element { double(z, x) return z } + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index fad68ec28f..34606e5f07 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 1365634154298741738 +// We include the hash to force the Go compiler to recompile: 16055964597816771835 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element.go b/ecc/stark-curve/fr/element.go index f393726632..65c70f9722 100644 --- a/ecc/stark-curve/fr/element.go +++ b/ecc/stark-curve/fr/element.go @@ -393,23 +393,6 @@ func (z *Element) fromMont() *Element { return z } -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], _ = bits.Add64(z[3], q3, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/stark-curve/fr/element_ops_arm64.go b/ecc/stark-curve/fr/element_ops_arm64.go index 604eec31d1..36cb58106c 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.go +++ b/ecc/stark-curve/fr/element_ops_arm64.go @@ -34,3 +34,11 @@ func (z *Element) Double(x *Element) *Element { double(z, x) return z } + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index fad68ec28f..34606e5f07 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 1365634154298741738 +// We include the hash to force the Go compiler to recompile: 16055964597816771835 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index c00eeb970d..8101795d8c 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -7,8 +7,8 @@ TEXT ·add(SB), NOSPLIT, $0-24 LDP x+8(FP), (R12, R13) LDP 0(R12), (R8, R9) - LDP 0(R13), (R4, R5) LDP 16(R12), (R10, R11) + LDP 0(R13), (R4, R5) LDP 16(R13), (R6, R7) ADDS R8, R4, R4 ADCS R9, R5, R5 @@ -61,3 +61,33 @@ TEXT ·double(SB), NOSPLIT, $0-16 STP (R2, R3), 0(R1) STP (R4, R5), 16(R1) RET + +// sub(res, x, y *Element) +TEXT ·sub(SB), NOSPLIT, $0-24 + LDP x+8(FP), (R12, R13) + LDP 0(R12), (R4, R5) + LDP 16(R12), (R6, R7) + LDP 0(R13), (R0, R1) + LDP 16(R13), (R2, R3) + SUBS R0, R4, R0 + SBCS R1, R5, R1 + SBCS R2, R6, R2 + SBCS R3, R7, R3 + + // load modulus and select + LDP ·qElement+0(SB), (R8, R9) + LDP ·qElement+16(SB), (R10, R11) + CSEL CS, ZR, R8, R8 + CSEL CS, ZR, R9, R9 + CSEL CS, ZR, R10, R10 + CSEL CS, ZR, R11, R11 + + // add q if underflow, 0 if not + ADDS R0, R8, R0 + ADCS R1, R9, R1 + ADCS R2, R10, R2 + ADCS R3, R11, R3 + MOVD res+0(FP), R14 + STP (R0, R1), 0(R14) + STP (R2, R3), 16(R14) + RET diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index d9cf214c2b..edfedc5f14 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -109,7 +109,7 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { f.generateAdd() f.generateDouble() - // f.generateSub() + f.generateSub() // f.generateNeg() return nil diff --git a/field/generator/asm/arm64/element_ops.go 
b/field/generator/asm/arm64/element_ops.go index 0a5923f7fb..205c5dabba 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -33,10 +33,8 @@ func (f *FFArm64) generateAdd() { f.LDP("x+8(FP)", xPtr, yPtr) - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(xPtr.At(i), x[i], x[i+1]) - f.LDP(yPtr.At(i), z[i], z[i+1]) - } + f.load(xPtr, x) + f.load(yPtr, z) f.ADDS(x[0], z[0], z[0]) for i := 1; i < f.NbWords; i++ { @@ -48,7 +46,7 @@ func (f *FFArm64) generateAdd() { f.Comment("store") f.MOVD("res+0(FP)", zPtr) - f.storeVector(z, zPtr) + f.store(z, zPtr) f.RET() @@ -67,9 +65,7 @@ func (f *FFArm64) generateDouble() { f.LDP("res+0(FP)", zPtr, xPtr) - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(xPtr.At(i), z[i], z[i+1]) - } + f.load(xPtr, z) f.ADDS(z[0], z[0], z[0]) for i := 1; i < f.NbWords; i++ { @@ -78,7 +74,7 @@ func (f *FFArm64) generateDouble() { f.reduce(z, t) - f.storeVector(z, zPtr) + f.store(z, zPtr) f.RET() @@ -93,56 +89,40 @@ func (f *FFArm64) generateSub() { // registers z := registers.PopN(f.NbWords) + x := registers.PopN(f.NbWords) + t := registers.PopN(f.NbWords) xPtr := registers.Pop() yPtr := registers.Pop() - ops := registers.PopN(2) + zPtr := registers.Pop() f.LDP("x+8(FP)", xPtr, yPtr) - f.Comment("load operands and subtract mod 2^r") - - op0 := f.SUBS - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(xPtr.At(i), z[i], ops[0]) - f.LDP(yPtr.At(i), z[i+1], ops[1]) - op0(z[i+1], z[i], z[i]) - op0 = f.SBCS + f.load(xPtr, x) + f.load(yPtr, z) - f.SBCS(ops[1], ops[0], z[i+1]) + f.SUBS(z[0], x[0], z[0]) + for i := 1; i < f.NbWords; i++ { + f.SBCS(z[i], x[i], z[i]) } - registers.Push(xPtr, yPtr) - registers.Push(ops...) - f.Comment("load modulus and select") - t := registers.PopN(f.NbWords) - zero := registers.Pop() - f.MOVD(0, zero) + zero := arm64.Register("ZR") for i := 0; i < f.NbWords-1; i += 2 { f.LDP(f.qAt(i), t[i], t[i+1]) - - f.CSEL("CS", zero, t[i], t[i]) - f.CSEL("CS", zero, t[i+1], t[i+1]) } - - registers.Push(zero) - - f.Comment("augment (or not)") - - op0 = f.ADDS for i := 0; i < f.NbWords; i++ { - op0(z[i], t[i], z[i]) - op0 = f.ADCS + f.CSEL("CS", zero, t[i], t[i]) + } + f.Comment("add q if underflow, 0 if not") + f.ADDS(z[0], t[0], z[0]) + for i := 1; i < f.NbWords; i++ { + f.ADCS(z[i], t[i], z[i]) } - registers.Push(t...) 
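// For reference, the code emitted by generateSub is equivalent to this
// pure-Go sketch (hypothetical helper for the 4-word case; q0..q3 are the
// modulus limb constants of the generated package, bits is math/bits). The
// carry-flag CSEL is modeled with an all-ones mask:
//
//	func subSketch(z, x, y *Element) {
//		var b uint64
//		z[0], b = bits.Sub64(x[0], y[0], 0)
//		z[1], b = bits.Sub64(x[1], y[1], b)
//		z[2], b = bits.Sub64(x[2], y[2], b)
//		z[3], b = bits.Sub64(x[3], y[3], b)
//		m := -b // all-ones iff the subtraction underflowed
//		var c uint64
//		z[0], c = bits.Add64(z[0], q0&m, 0)
//		z[1], c = bits.Add64(z[1], q1&m, c)
//		z[2], c = bits.Add64(z[2], q2&m, c)
//		z[3], _ = bits.Add64(z[3], q3&m, c)
//	}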
- - f.Comment("store") - zPtr := registers.Pop() f.MOVD("res+0(FP)", zPtr) - f.storeVector(z, zPtr) + f.store(z, zPtr) f.RET() @@ -187,7 +167,7 @@ func (f *FFArm64) generateNeg() { } f.Comment("store") - f.storeVector(z, zPtr) + f.store(z, zPtr) f.RET() @@ -215,7 +195,13 @@ func (f *FFArm64) reduce(z, t []arm64.Register) { } } -func (f *FFArm64) storeVector(z []arm64.Register, zPtr arm64.Register) { +func (f *FFArm64) load(zPtr arm64.Register, z []arm64.Register) { + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(zPtr.At(i), z[i], z[i+1]) + } +} + +func (f *FFArm64) store(z []arm64.Register, zPtr arm64.Register) { for i := 0; i < f.NbWords-1; i += 2 { f.STP(z[i], z[i+1], zPtr.At(i)) } diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index 8cd623514a..a31e4e8611 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -474,7 +474,6 @@ func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} { return z } -{{- end}} // Sub z = x - y (mod q) @@ -502,6 +501,9 @@ func (z *{{.ElementName}}) Sub( x, y *{{.ElementName}}) *{{.ElementName}} { return z } +{{- end}} + + // Neg z = q - x func (z *{{.ElementName}}) Neg( x *{{.ElementName}}) *{{.ElementName}} { if x.IsZero() { diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 896b19b857..5f9d7a8245 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -214,5 +214,14 @@ func (z *{{.ElementName}}) Double(x *{{.ElementName}}) *{{.ElementName}} { double(z,x) return z } + +//go:noescape +func sub(res,x,y *{{.ElementName}}) + +func (z *{{.ElementName}}) Sub(x, y *{{.ElementName}}) *{{.ElementName}} { + sub(z,x,y) + return z +} + {{end}} ` From 637a68af090a5ea7c2584b64db3e58474a574660 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 16 Oct 2024 17:25:28 +0000 Subject: [PATCH 09/74] feat: added butterfly asm, experiment --- ecc/bls12-377/fp/element_ops_purego.go | 8 --- ecc/bls12-377/fr/element_ops_arm64.go | 3 + ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-377/fr/element_ops_purego.go | 8 --- ecc/bls12-381/fp/element_ops_purego.go | 8 --- ecc/bls12-381/fr/element_ops_arm64.go | 3 + ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fr/element_ops_purego.go | 8 --- ecc/bls24-315/fp/element_ops_purego.go | 1 + ecc/bls24-315/fr/element_ops_arm64.go | 3 + ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element_ops_purego.go | 8 --- ecc/bls24-317/fp/element_ops_purego.go | 1 + ecc/bls24-317/fr/element_ops_arm64.go | 3 + ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element_ops_purego.go | 8 --- ecc/bn254/fp/element_ops_arm64.go | 3 + ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fp/element_ops_purego.go | 8 --- ecc/bn254/fr/element_ops_arm64.go | 3 + ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/bn254/fr/element_ops_purego.go | 8 --- ecc/bw6-633/fp/element_ops_purego.go | 1 + ecc/bw6-633/fr/element_ops_purego.go | 1 + ecc/bw6-761/fp/element_ops_purego.go | 1 + ecc/bw6-761/fr/element_ops_purego.go | 8 --- ecc/secp256k1/fp/element_ops_purego.go | 1 + ecc/secp256k1/fr/element_ops_purego.go | 1 + ecc/stark-curve/fp/element_ops_arm64.go | 3 + ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fp/element_ops_purego.go | 8 --- ecc/stark-curve/fr/element_ops_arm64.go | 3 + ecc/stark-curve/fr/element_ops_arm64.s 
| 2 +- ecc/stark-curve/fr/element_ops_purego.go | 8 --- field/asm/element_4w_arm64.s | 53 ++++++++++++++ field/generator/asm/arm64/build.go | 2 +- field/generator/asm/arm64/element_ops.go | 69 +++++++++++-------- .../internal/templates/element/base.go | 2 +- .../internal/templates/element/ops_asm.go | 3 + .../internal/templates/element/ops_purego.go | 3 + field/goldilocks/element_ops_purego.go | 1 + 41 files changed, 143 insertions(+), 125 deletions(-) diff --git a/ecc/bls12-377/fp/element_ops_purego.go b/ecc/bls12-377/fp/element_ops_purego.go index 072fb87c01..c2d525e2bf 100644 --- a/ecc/bls12-377/fp/element_ops_purego.go +++ b/ecc/bls12-377/fp/element_ops_purego.go @@ -46,14 +46,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/bls12-377/fr/element_ops_arm64.go b/ecc/bls12-377/fr/element_ops_arm64.go index 36cb58106c..00cbb30388 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.go +++ b/ecc/bls12-377/fr/element_ops_arm64.go @@ -42,3 +42,6 @@ func (z *Element) Sub(x, y *Element) *Element { sub(z, x, y) return z } + +//go:noescape +func Butterfly(a, b *Element) diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index 34606e5f07..1ae94dfd6a 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 16055964597816771835 +// We include the hash to force the Go compiler to recompile: 2806534830277291526 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-377/fr/element_ops_purego.go b/ecc/bls12-377/fr/element_ops_purego.go index f107066c79..583a74f359 100644 --- a/ecc/bls12-377/fr/element_ops_purego.go +++ b/ecc/bls12-377/fr/element_ops_purego.go @@ -44,14 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/bls12-381/fp/element_ops_purego.go b/ecc/bls12-381/fp/element_ops_purego.go index ee3f7e7408..ecbfdb23cf 100644 --- a/ecc/bls12-381/fp/element_ops_purego.go +++ b/ecc/bls12-381/fp/element_ops_purego.go @@ -46,14 +46,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/bls12-381/fr/element_ops_arm64.go b/ecc/bls12-381/fr/element_ops_arm64.go index 36cb58106c..00cbb30388 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.go +++ b/ecc/bls12-381/fr/element_ops_arm64.go @@ -42,3 +42,6 @@ func (z *Element) Sub(x, y *Element) *Element { sub(z, x, y) return z } + +//go:noescape +func Butterfly(a, b *Element) diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s index 34606e5f07..1ae94dfd6a 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
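// Unlike add, double and sub, Butterfly needs no Go wrapper: the assembly
// symbol is declared directly under its exported name (see the
// element_ops_arm64.go hunks above):
//
//	//go:noescape
//	func Butterfly(a, b *Element)
//
// It updates both operands in place: a = a + b (mod q), b = a_old - b (mod q).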
-// We include the hash to force the Go compiler to recompile: 16055964597816771835 +// We include the hash to force the Go compiler to recompile: 2806534830277291526 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fr/element_ops_purego.go b/ecc/bls12-381/fr/element_ops_purego.go index 8c10496433..7860a46621 100644 --- a/ecc/bls12-381/fr/element_ops_purego.go +++ b/ecc/bls12-381/fr/element_ops_purego.go @@ -44,14 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/bls24-315/fp/element_ops_purego.go b/ecc/bls24-315/fp/element_ops_purego.go index 4796fc3c5b..348a99f991 100644 --- a/ecc/bls24-315/fp/element_ops_purego.go +++ b/ecc/bls24-315/fp/element_ops_purego.go @@ -45,6 +45,7 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. // Butterfly sets // // a = a + b (mod q) diff --git a/ecc/bls24-315/fr/element_ops_arm64.go b/ecc/bls24-315/fr/element_ops_arm64.go index 36cb58106c..00cbb30388 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.go +++ b/ecc/bls24-315/fr/element_ops_arm64.go @@ -42,3 +42,6 @@ func (z *Element) Sub(x, y *Element) *Element { sub(z, x, y) return z } + +//go:noescape +func Butterfly(a, b *Element) diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s index 34606e5f07..1ae94dfd6a 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 16055964597816771835 +// We include the hash to force the Go compiler to recompile: 2806534830277291526 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element_ops_purego.go b/ecc/bls24-315/fr/element_ops_purego.go index e7a8817f01..4eb7e31b87 100644 --- a/ecc/bls24-315/fr/element_ops_purego.go +++ b/ecc/bls24-315/fr/element_ops_purego.go @@ -44,14 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/bls24-317/fp/element_ops_purego.go b/ecc/bls24-317/fp/element_ops_purego.go index 9f72e6f84b..6663c98bbf 100644 --- a/ecc/bls24-317/fp/element_ops_purego.go +++ b/ecc/bls24-317/fp/element_ops_purego.go @@ -45,6 +45,7 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. // Butterfly sets // // a = a + b (mod q) diff --git a/ecc/bls24-317/fr/element_ops_arm64.go b/ecc/bls24-317/fr/element_ops_arm64.go index 36cb58106c..00cbb30388 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.go +++ b/ecc/bls24-317/fr/element_ops_arm64.go @@ -42,3 +42,6 @@ func (z *Element) Sub(x, y *Element) *Element { sub(z, x, y) return z } + +//go:noescape +func Butterfly(a, b *Element) diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s index 34606e5f07..1ae94dfd6a 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
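// On targets without this assembly, Butterfly keeps falling back to
// _butterflyGeneric; given the documented semantics, that fallback amounts
// to the following sketch (hypothetical name):
//
//	func butterflyGenericSketch(a, b *Element) {
//		t := *a       // save the old a
//		a.Add(a, b)   // a = a + b (mod q)
//		b.Sub(&t, b)  // b = a_old - b (mod q)
//	}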
-// We include the hash to force the Go compiler to recompile: 16055964597816771835 +// We include the hash to force the Go compiler to recompile: 2806534830277291526 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element_ops_purego.go b/ecc/bls24-317/fr/element_ops_purego.go index 7afd9cc8df..c0644ee03d 100644 --- a/ecc/bls24-317/fr/element_ops_purego.go +++ b/ecc/bls24-317/fr/element_ops_purego.go @@ -44,14 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/bn254/fp/element_ops_arm64.go b/ecc/bn254/fp/element_ops_arm64.go index d09778d0d0..0133155d8e 100644 --- a/ecc/bn254/fp/element_ops_arm64.go +++ b/ecc/bn254/fp/element_ops_arm64.go @@ -42,3 +42,6 @@ func (z *Element) Sub(x, y *Element) *Element { sub(z, x, y) return z } + +//go:noescape +func Butterfly(a, b *Element) diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s index 34606e5f07..1ae94dfd6a 100644 --- a/ecc/bn254/fp/element_ops_arm64.s +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 16055964597816771835 +// We include the hash to force the Go compiler to recompile: 2806534830277291526 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element_ops_purego.go b/ecc/bn254/fp/element_ops_purego.go index 454376da57..1b404b5589 100644 --- a/ecc/bn254/fp/element_ops_purego.go +++ b/ecc/bn254/fp/element_ops_purego.go @@ -44,14 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/bn254/fr/element_ops_arm64.go b/ecc/bn254/fr/element_ops_arm64.go index 36cb58106c..00cbb30388 100644 --- a/ecc/bn254/fr/element_ops_arm64.go +++ b/ecc/bn254/fr/element_ops_arm64.go @@ -42,3 +42,6 @@ func (z *Element) Sub(x, y *Element) *Element { sub(z, x, y) return z } + +//go:noescape +func Butterfly(a, b *Element) diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s index 34606e5f07..1ae94dfd6a 100644 --- a/ecc/bn254/fr/element_ops_arm64.s +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 16055964597816771835 +// We include the hash to force the Go compiler to recompile: 2806534830277291526 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element_ops_purego.go b/ecc/bn254/fr/element_ops_purego.go index 4ea220c185..721f5e7f39 100644 --- a/ecc/bn254/fr/element_ops_purego.go +++ b/ecc/bn254/fr/element_ops_purego.go @@ -44,14 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/bw6-633/fp/element_ops_purego.go b/ecc/bw6-633/fp/element_ops_purego.go index 3b5d489a30..56c5798d50 100644 --- a/ecc/bw6-633/fp/element_ops_purego.go +++ b/ecc/bw6-633/fp/element_ops_purego.go @@ -50,6 +50,7 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. 
// Butterfly sets // // a = a + b (mod q) diff --git a/ecc/bw6-633/fr/element_ops_purego.go b/ecc/bw6-633/fr/element_ops_purego.go index 4a7cdbfe46..aa5e785c1f 100644 --- a/ecc/bw6-633/fr/element_ops_purego.go +++ b/ecc/bw6-633/fr/element_ops_purego.go @@ -45,6 +45,7 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. // Butterfly sets // // a = a + b (mod q) diff --git a/ecc/bw6-761/fp/element_ops_purego.go b/ecc/bw6-761/fp/element_ops_purego.go index 59d6d1d523..9ff53651dc 100644 --- a/ecc/bw6-761/fp/element_ops_purego.go +++ b/ecc/bw6-761/fp/element_ops_purego.go @@ -52,6 +52,7 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. // Butterfly sets // // a = a + b (mod q) diff --git a/ecc/bw6-761/fr/element_ops_purego.go b/ecc/bw6-761/fr/element_ops_purego.go index bdf76428de..d9bc7f95a3 100644 --- a/ecc/bw6-761/fr/element_ops_purego.go +++ b/ecc/bw6-761/fr/element_ops_purego.go @@ -46,14 +46,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/secp256k1/fp/element_ops_purego.go b/ecc/secp256k1/fp/element_ops_purego.go index f53ffa325b..9059e82e31 100644 --- a/ecc/secp256k1/fp/element_ops_purego.go +++ b/ecc/secp256k1/fp/element_ops_purego.go @@ -41,6 +41,7 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. // Butterfly sets // // a = a + b (mod q) diff --git a/ecc/secp256k1/fr/element_ops_purego.go b/ecc/secp256k1/fr/element_ops_purego.go index ef83ea20a5..eb7f9781e7 100644 --- a/ecc/secp256k1/fr/element_ops_purego.go +++ b/ecc/secp256k1/fr/element_ops_purego.go @@ -41,6 +41,7 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. // Butterfly sets // // a = a + b (mod q) diff --git a/ecc/stark-curve/fp/element_ops_arm64.go b/ecc/stark-curve/fp/element_ops_arm64.go index d09778d0d0..0133155d8e 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.go +++ b/ecc/stark-curve/fp/element_ops_arm64.go @@ -42,3 +42,6 @@ func (z *Element) Sub(x, y *Element) *Element { sub(z, x, y) return z } + +//go:noescape +func Butterfly(a, b *Element) diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index 34606e5f07..1ae94dfd6a 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
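// The TEXT ·Butterfly block added to field/asm/element_4w_arm64.s further
// down performs both halves in a single pass:
//
//	1. load a and b, compute s = a + b with ADDS/ADCS;
//	2. reduce s by subtracting q (SUBS/SBCS); CSEL CS keeps the reduced
//	   value only if that subtraction did not borrow; store to a;
//	3. compute d = a_old - b with SUBS/SBCS, CSEL CS zeroes the q limbs
//	   when there was no borrow, then ADDS/ADCS adds q back on underflow;
//	   store to b.
//
// The two elements are each loaded and stored exactly once; only the
// modulus limbs are reloaded between the two halves.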
-// We include the hash to force the Go compiler to recompile: 16055964597816771835 +// We include the hash to force the Go compiler to recompile: 2806534830277291526 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fp/element_ops_purego.go b/ecc/stark-curve/fp/element_ops_purego.go index 19cb3649be..3b2b30110f 100644 --- a/ecc/stark-curve/fp/element_ops_purego.go +++ b/ecc/stark-curve/fp/element_ops_purego.go @@ -44,14 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/stark-curve/fr/element_ops_arm64.go b/ecc/stark-curve/fr/element_ops_arm64.go index 36cb58106c..00cbb30388 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.go +++ b/ecc/stark-curve/fr/element_ops_arm64.go @@ -42,3 +42,6 @@ func (z *Element) Sub(x, y *Element) *Element { sub(z, x, y) return z } + +//go:noescape +func Butterfly(a, b *Element) diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index 34606e5f07..1ae94dfd6a 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 16055964597816771835 +// We include the hash to force the Go compiler to recompile: 2806534830277291526 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_ops_purego.go b/ecc/stark-curve/fr/element_ops_purego.go index 2d0db69153..2a57b19d86 100644 --- a/ecc/stark-curve/fr/element_ops_purego.go +++ b/ecc/stark-curve/fr/element_ops_purego.go @@ -44,14 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index 8101795d8c..76a18f25a5 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -91,3 +91,56 @@ TEXT ·sub(SB), NOSPLIT, $0-24 STP (R0, R1), 0(R14) STP (R2, R3), 16(R14) RET + +// butterfly(x, y *Element) +TEXT ·Butterfly(SB), NOSPLIT, $0-16 + LDP x+0(FP), (R16, R17) + LDP 0(R16), (R0, R1) + LDP 16(R16), (R2, R3) + LDP 0(R17), (R4, R5) + LDP 16(R17), (R6, R7) + ADDS R0, R4, R8 + ADCS R1, R5, R9 + ADCS R2, R6, R10 + ADCS R3, R7, R11 + + // load modulus and subtract + LDP ·qElement+0(SB), (R12, R13) + LDP ·qElement+16(SB), (R14, R15) + SUBS R12, R8, R12 + SBCS R13, R9, R13 + SBCS R14, R10, R14 + SBCS R15, R11, R15 + + // reduce if necessary + CSEL CS, R12, R8, R8 + CSEL CS, R13, R9, R9 + CSEL CS, R14, R10, R10 + CSEL CS, R15, R11, R11 + + // store + STP (R8, R9), 0(R16) + STP (R10, R11), 16(R16) + SUBS R4, R0, R4 + SBCS R5, R1, R5 + SBCS R6, R2, R6 + SBCS R7, R3, R7 + + // load modulus and select + LDP ·qElement+0(SB), (R12, R13) + LDP ·qElement+16(SB), (R14, R15) + CSEL CS, ZR, R12, R12 + CSEL CS, ZR, R13, R13 + CSEL CS, ZR, R14, R14 + CSEL CS, ZR, R15, R15 + + // add q if underflow, 0 if not + ADDS R4, R12, R4 + ADCS R5, R13, R5 + ADCS R6, R14, R6 + ADCS R7, R15, R7 + + // store + STP (R4, R5), 0(R17) + STP (R6, R7), 16(R17) + RET diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index edfedc5f14..6824a06edb 100644 --- 
a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -110,7 +110,7 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { f.generateAdd() f.generateDouble() f.generateSub() - // f.generateNeg() + f.generateButterfly() return nil } diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index 205c5dabba..d75d21e2fa 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -128,49 +128,64 @@ func (f *FFArm64) generateSub() { } -func (f *FFArm64) generateNeg() { - f.Comment("neg(res, x *Element)") - registers := f.FnHeader("neg", 0, 16) +func (f *FFArm64) generateButterfly() { + f.Comment("butterfly(x, y *Element)") + registers := f.FnHeader("Butterfly", 0, 16) defer f.AssertCleanStack(0, 0) - + // Butterfly sets + // a = a + b (mod q) + // b = a - b (mod q) // registers - z := registers.PopN(f.NbWords) - xPtr := registers.Pop() - zPtr := registers.Pop() - ops := registers.PopN(2) - xNotZero := registers.Pop() + a := registers.PopN(f.NbWords) + b := registers.PopN(f.NbWords) + aRes := registers.PopN(f.NbWords) + t := registers.PopN(f.NbWords) + aPtr := registers.Pop() + bPtr := registers.Pop() - f.LDP("res+0(FP)", zPtr, xPtr) - f.Comment("load operands and subtract") + f.LDP("x+0(FP)", aPtr, bPtr) + f.load(aPtr, a) + f.load(bPtr, b) - f.MOVD(0, xNotZero) - op0 := f.SUBS - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(xPtr.At(i), z[i], z[i+1]) - f.LDP(f.qAt(i), ops[0], ops[1]) + f.ADDS(a[0], b[0], aRes[0]) + for i := 1; i < f.NbWords; i++ { + f.ADCS(a[i], b[i], aRes[i]) + } + + f.reduce(aRes, t) - f.ORR(z[i], xNotZero, xNotZero, "has x been 0 so far?") - f.ORR(z[i+1], xNotZero, xNotZero) + f.Comment("store") + + f.store(aRes, aPtr) - op0(z[i], ops[0], z[i]) - op0 = f.SBCS + bRes := b - f.SBCS(z[i+1], ops[1], z[i+1]) + f.SUBS(b[0], a[0], bRes[0]) + for i := 1; i < f.NbWords; i++ { + f.SBCS(b[i], a[i], bRes[i]) } - registers.Push(xPtr) - registers.Push(ops...) + f.Comment("load modulus and select") + + zero := arm64.Register("ZR") - f.TST(-1, xNotZero) + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(f.qAt(i), t[i], t[i+1]) + } for i := 0; i < f.NbWords; i++ { - f.CSEL("EQ", xNotZero, z[i], z[i]) + f.CSEL("CS", zero, t[i], t[i]) + } + f.Comment("add q if underflow, 0 if not") + f.ADDS(bRes[0], t[0], bRes[0]) + for i := 1; i < f.NbWords; i++ { + f.ADCS(bRes[i], t[i], bRes[i]) } f.Comment("store") - f.store(z, zPtr) - f.RET() + f.store(bRes, bPtr) + f.RET() } func (f *FFArm64) reduce(z, t []arm64.Register) { diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index a31e4e8611..65ccb7a68b 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -503,7 +503,6 @@ func (z *{{.ElementName}}) Sub( x, y *{{.ElementName}}) *{{.ElementName}} { {{- end}} - // Neg z = q - x func (z *{{.ElementName}}) Neg( x *{{.ElementName}}) *{{.ElementName}} { if x.IsZero() { @@ -526,6 +525,7 @@ func (z *{{.ElementName}}) Neg( x *{{.ElementName}}) *{{.ElementName}} { return z } + // Select is a constant-time conditional move. // If c=0, z = x0. 
Else z = x1 func (z *{{.ElementName}}) Select(c int, x0 *{{.ElementName}}, x1 *{{.ElementName}}) *{{.ElementName}} { diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 5f9d7a8245..5f596cfcf4 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -223,5 +223,8 @@ func (z *{{.ElementName}}) Sub(x, y *{{.ElementName}}) *{{.ElementName}} { return z } +//go:noescape +func Butterfly(a, b *{{.ElementName}}) + {{end}} ` diff --git a/field/generator/internal/templates/element/ops_purego.go b/field/generator/internal/templates/element/ops_purego.go index fe3cead3c4..cce771aa1b 100644 --- a/field/generator/internal/templates/element/ops_purego.go +++ b/field/generator/internal/templates/element/ops_purego.go @@ -34,12 +34,15 @@ func MulBy{{$i}}(x *{{$.ElementName}}) { {{- end}} +{{- if not .ASMArm}} +// TODO @gbotrel fixme. // Butterfly sets // a = a + b (mod q) // b = a - b (mod q) func Butterfly(a, b *{{.ElementName}}) { _butterflyGeneric(a, b) } +{{- end}} func fromMont(z *{{.ElementName}} ) { diff --git a/field/goldilocks/element_ops_purego.go b/field/goldilocks/element_ops_purego.go index 5dc51ea242..1985eeb161 100644 --- a/field/goldilocks/element_ops_purego.go +++ b/field/goldilocks/element_ops_purego.go @@ -39,6 +39,7 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. // Butterfly sets // // a = a + b (mod q) From 222036aa6069f671bdd3b3c955ab055c95368429 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 16 Oct 2024 20:10:23 +0000 Subject: [PATCH 10/74] checkpoint --- ecc/bls12-377/fp/element.go | 69 +++++++ ecc/bls12-377/fp/element_ops_purego.go | 9 + ecc/bls12-377/fr/element_ops_arm64.go | 8 + ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-377/fr/element_ops_purego.go | 172 ------------------ ecc/bls12-381/fp/element.go | 69 +++++++ ecc/bls12-381/fp/element_ops_purego.go | 9 + ecc/bls12-381/fr/element_ops_arm64.go | 8 + ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fr/element_ops_purego.go | 172 ------------------ ecc/bls24-315/fr/element_ops_arm64.go | 8 + ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element_ops_purego.go | 172 ------------------ ecc/bls24-317/fr/element_ops_arm64.go | 8 + ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element_ops_purego.go | 172 ------------------ ecc/bn254/fp/element_ops_arm64.go | 8 + ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fp/element_ops_purego.go | 172 ------------------ ecc/bn254/fr/element_ops_arm64.go | 8 + ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/bn254/fr/element_ops_purego.go | 172 ------------------ ecc/bw6-761/fr/element.go | 69 +++++++ ecc/bw6-761/fr/element_ops_purego.go | 9 + ecc/stark-curve/fp/element_ops_arm64.go | 8 + ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fp/element_ops_purego.go | 172 ------------------ ecc/stark-curve/fr/element_ops_arm64.go | 8 + ecc/stark-curve/fr/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element_ops_purego.go | 172 ------------------ field/asm/element_4w_arm64.s | 88 +++++++++ field/generator/asm/arm64/build.go | 2 + field/generator/asm/arm64/element_ops.go | 152 +++++++++++++++- field/generator/config/field_config.go | 2 +- .../internal/templates/element/ops_asm.go | 8 + .../internal/templates/element/ops_purego.go | 3 + 36 files changed, 554 insertions(+), 1391 deletions(-) diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 
408b94653d..393f45744d 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -418,6 +418,75 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], carry = bits.Add64(x[3], y[3], carry) + z[4], carry = bits.Add64(x[4], y[4], carry) + z[5], _ = bits.Add64(x[5], y[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], carry = bits.Add64(x[3], x[3], carry) + z[4], carry = bits.Add64(x[4], x[4], carry) + z[5], _ = bits.Add64(x[5], x[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + z[4], b = bits.Sub64(x[4], y[4], b) + z[5], b = bits.Sub64(x[5], y[5], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], c = bits.Add64(z[3], q3, c) + z[4], c = bits.Add64(z[4], q4, c) + z[5], _ = bits.Add64(z[5], q5, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-377/fp/element_ops_purego.go b/ecc/bls12-377/fp/element_ops_purego.go index c2d525e2bf..5b720d9fde 100644 --- a/ecc/bls12-377/fp/element_ops_purego.go +++ b/ecc/bls12-377/fp/element_ops_purego.go @@ -46,6 +46,15 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/bls12-377/fr/element_ops_arm64.go b/ecc/bls12-377/fr/element_ops_arm64.go index 00cbb30388..9a57c7ca47 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.go +++ b/ecc/bls12-377/fr/element_ops_arm64.go @@ -45,3 +45,11 @@ func (z *Element) Sub(x, y *Element) *Element { //go:noescape func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index 1ae94dfd6a..5843be3b0a 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
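// Patch 10 moves Mul behind the same unexported-asm-plus-wrapper pattern,
// which is why the long pure-Go Montgomery multiplication is deleted from
// the 4-word element_ops_purego.go files below. Conversely, the 6-word
// fields (bls12-377/fp, bls12-381/fp, bw6-761/fr) get plain Add/Double/Sub
// re-added to element.go: the only shared arm64 file generated so far is
// element_4w_arm64.s, i.e. 4x64-bit limbs, so the wider moduli stay on the
// generic code path.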
-// We include the hash to force the Go compiler to recompile: 2806534830277291526 +// We include the hash to force the Go compiler to recompile: 9556532307245233012 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-377/fr/element_ops_purego.go b/ecc/bls12-377/fr/element_ops_purego.go index 583a74f359..6108a57091 100644 --- a/ecc/bls12-377/fr/element_ops_purego.go +++ b/ecc/bls12-377/fr/element_ops_purego.go @@ -89,178 +89,6 @@ func (vector *Vector) Mul(a, b Vector) { mulVecGeneric(*vector, a, b) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = 
bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index 1609ca9524..f0bcfe51bc 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -418,6 +418,75 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], carry = bits.Add64(x[3], y[3], carry) + z[4], carry = bits.Add64(x[4], y[4], carry) + z[5], _ = bits.Add64(x[5], y[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], carry = bits.Add64(x[3], x[3], carry) + z[4], carry = bits.Add64(x[4], x[4], carry) + z[5], _ = bits.Add64(x[5], x[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + z[4], b = bits.Sub64(x[4], y[4], b) + z[5], b = bits.Sub64(x[5], y[5], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], c = bits.Add64(z[3], q3, c) + z[4], c = bits.Add64(z[4], q4, c) + z[5], _ = 
bits.Add64(z[5], q5, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-381/fp/element_ops_purego.go b/ecc/bls12-381/fp/element_ops_purego.go index ecbfdb23cf..e9d87cc4a7 100644 --- a/ecc/bls12-381/fp/element_ops_purego.go +++ b/ecc/bls12-381/fp/element_ops_purego.go @@ -46,6 +46,15 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/bls12-381/fr/element_ops_arm64.go b/ecc/bls12-381/fr/element_ops_arm64.go index 00cbb30388..9a57c7ca47 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.go +++ b/ecc/bls12-381/fr/element_ops_arm64.go @@ -45,3 +45,11 @@ func (z *Element) Sub(x, y *Element) *Element { //go:noescape func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s index 1ae94dfd6a..5843be3b0a 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 2806534830277291526 +// We include the hash to force the Go compiler to recompile: 9556532307245233012 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fr/element_ops_purego.go b/ecc/bls12-381/fr/element_ops_purego.go index 7860a46621..ee62abea36 100644 --- a/ecc/bls12-381/fr/element_ops_purego.go +++ b/ecc/bls12-381/fr/element_ops_purego.go @@ -89,178 +89,6 @@ func (vector *Vector) Mul(a, b Vector) { mulVecGeneric(*vector, a, b) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) 
- u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/ecc/bls24-315/fr/element_ops_arm64.go b/ecc/bls24-315/fr/element_ops_arm64.go index 00cbb30388..9a57c7ca47 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.go +++ b/ecc/bls24-315/fr/element_ops_arm64.go @@ -45,3 +45,11 @@ func (z *Element) Sub(x, y *Element) *Element { //go:noescape func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s index 1ae94dfd6a..5843be3b0a 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 2806534830277291526 +// We include the hash to force the Go compiler to recompile: 9556532307245233012 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element_ops_purego.go b/ecc/bls24-315/fr/element_ops_purego.go index 4eb7e31b87..e35913169e 100644 --- a/ecc/bls24-315/fr/element_ops_purego.go +++ b/ecc/bls24-315/fr/element_ops_purego.go @@ -89,178 +89,6 @@ func (vector *Vector) Mul(a, b Vector) { mulVecGeneric(*vector, a, b) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) 
- u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/ecc/bls24-317/fr/element_ops_arm64.go b/ecc/bls24-317/fr/element_ops_arm64.go index 00cbb30388..9a57c7ca47 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.go +++ b/ecc/bls24-317/fr/element_ops_arm64.go @@ -45,3 +45,11 @@ func (z *Element) Sub(x, y *Element) *Element { //go:noescape func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s index 1ae94dfd6a..5843be3b0a 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 2806534830277291526 +// We include the hash to force the Go compiler to recompile: 9556532307245233012 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element_ops_purego.go b/ecc/bls24-317/fr/element_ops_purego.go index c0644ee03d..57e48f309b 100644 --- a/ecc/bls24-317/fr/element_ops_purego.go +++ b/ecc/bls24-317/fr/element_ops_purego.go @@ -89,178 +89,6 @@ func (vector *Vector) Mul(a, b Vector) { mulVecGeneric(*vector, a, b) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) 
- u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/ecc/bn254/fp/element_ops_arm64.go b/ecc/bn254/fp/element_ops_arm64.go index 0133155d8e..1e0188e13a 100644 --- a/ecc/bn254/fp/element_ops_arm64.go +++ b/ecc/bn254/fp/element_ops_arm64.go @@ -45,3 +45,11 @@ func (z *Element) Sub(x, y *Element) *Element { //go:noescape func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s index 1ae94dfd6a..5843be3b0a 100644 --- a/ecc/bn254/fp/element_ops_arm64.s +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 2806534830277291526 +// We include the hash to force the Go compiler to recompile: 9556532307245233012 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element_ops_purego.go b/ecc/bn254/fp/element_ops_purego.go index 1b404b5589..6d41d6578f 100644 --- a/ecc/bn254/fp/element_ops_purego.go +++ b/ecc/bn254/fp/element_ops_purego.go @@ -89,178 +89,6 @@ func (vector *Vector) Mul(a, b Vector) { mulVecGeneric(*vector, a, b) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) 
- u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/ecc/bn254/fr/element_ops_arm64.go b/ecc/bn254/fr/element_ops_arm64.go index 00cbb30388..9a57c7ca47 100644 --- a/ecc/bn254/fr/element_ops_arm64.go +++ b/ecc/bn254/fr/element_ops_arm64.go @@ -45,3 +45,11 @@ func (z *Element) Sub(x, y *Element) *Element { //go:noescape func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s index 1ae94dfd6a..5843be3b0a 100644 --- a/ecc/bn254/fr/element_ops_arm64.s +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 2806534830277291526 +// We include the hash to force the Go compiler to recompile: 9556532307245233012 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element_ops_purego.go b/ecc/bn254/fr/element_ops_purego.go index 721f5e7f39..859949859a 100644 --- a/ecc/bn254/fr/element_ops_purego.go +++ b/ecc/bn254/fr/element_ops_purego.go @@ -89,178 +89,6 @@ func (vector *Vector) Mul(a, b Vector) { mulVecGeneric(*vector, a, b) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) 
- u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index a887b71537..6784bc911f 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -418,6 +418,75 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], carry = bits.Add64(x[3], y[3], carry) + z[4], carry = bits.Add64(x[4], y[4], carry) + z[5], _ = bits.Add64(x[5], y[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], carry = bits.Add64(x[3], x[3], carry) + z[4], carry = bits.Add64(x[4], x[4], carry) + z[5], _ = bits.Add64(x[5], x[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + z[4], b = bits.Sub64(x[4], y[4], b) + z[5], b = bits.Sub64(x[5], y[5], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], c = bits.Add64(z[3], q3, c) + z[4], c = bits.Add64(z[4], q4, c) + z[5], _ = bits.Add64(z[5], q5, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bw6-761/fr/element_ops_purego.go b/ecc/bw6-761/fr/element_ops_purego.go index d9bc7f95a3..5090c3e9cd 100644 --- a/ecc/bw6-761/fr/element_ops_purego.go +++ b/ecc/bw6-761/fr/element_ops_purego.go @@ -46,6 +46,15 @@ func MulBy13(x *Element) { x.Mul(x, &y) } +// TODO @gbotrel fixme. 
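+//
+// (Until a dedicated arm64 Butterfly is wired in for this field, the call
+// routes to the generic fallback. As a rough sketch, assuming the shared
+// template's _butterflyGeneric helper, the fallback amounts to:
+//
+//	t := *a
+//	a.Add(a, b)
+//	b.Sub(&t, b)
+//
+// one temporary, one modular add, one modular sub.)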
+// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + func fromMont(z *Element) { _fromMontGeneric(z) } diff --git a/ecc/stark-curve/fp/element_ops_arm64.go b/ecc/stark-curve/fp/element_ops_arm64.go index 0133155d8e..1e0188e13a 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.go +++ b/ecc/stark-curve/fp/element_ops_arm64.go @@ -45,3 +45,11 @@ func (z *Element) Sub(x, y *Element) *Element { //go:noescape func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index 1ae94dfd6a..5843be3b0a 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 2806534830277291526 +// We include the hash to force the Go compiler to recompile: 9556532307245233012 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fp/element_ops_purego.go b/ecc/stark-curve/fp/element_ops_purego.go index 3b2b30110f..189eb1054a 100644 --- a/ecc/stark-curve/fp/element_ops_purego.go +++ b/ecc/stark-curve/fp/element_ops_purego.go @@ -89,178 +89,6 @@ func (vector *Vector) Mul(a, b Vector) { mulVecGeneric(*vector, a, b) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) 
- t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/ecc/stark-curve/fr/element_ops_arm64.go b/ecc/stark-curve/fr/element_ops_arm64.go index 00cbb30388..9a57c7ca47 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.go +++ b/ecc/stark-curve/fr/element_ops_arm64.go @@ -45,3 +45,11 @@ func (z *Element) Sub(x, y *Element) *Element { //go:noescape func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index 1ae94dfd6a..5843be3b0a 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
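// (This stub only #includes the shared assembly, so the Go build cache keys
// on the stub's own bytes; embedding a hash of the shared file below makes
// any change to the included assembly alter this stub and force a rebuild.)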
-// We include the hash to force the Go compiler to recompile: 2806534830277291526 +// We include the hash to force the Go compiler to recompile: 9556532307245233012 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_ops_purego.go b/ecc/stark-curve/fr/element_ops_purego.go index 2a57b19d86..ac4346db5b 100644 --- a/ecc/stark-curve/fr/element_ops_purego.go +++ b/ecc/stark-curve/fr/element_ops_purego.go @@ -89,178 +89,6 @@ func (vector *Vector) Mul(a, b Vector) { mulVecGeneric(*vector, a, b) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ 
= bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index 76a18f25a5..be6a78fc3f 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -144,3 +144,91 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 STP (R4, R5), 0(R17) STP (R6, R7), 16(R17) RET + +// mul(res, x, y *Element) +TEXT ·mul(SB), NOSPLIT, $0-24 + LDP x+8(FP), (R0, R1) + LDP 0(R0), (R3, R4) + LDP 16(R0), (R5, R6) + LDP ·qElement+0(SB), (R12, R13) + LDP ·qElement+16(SB), (R14, R15) + +#define DIVSHIFT() \ + MOVD $const_qInvNeg, R2 \ + MUL R7, R2, R2 \ + MUL R12, R2, R0 \ + ADDS R0, R7, R7 \ + MUL R13, R2, R0 \ + ADCS R0, R8, R8 \ + MUL R14, R2, R0 \ + ADCS R0, R9, R9 \ + MUL R15, R2, R0 \ + ADCS R0, R10, R10 \ + ADCS ZR, R11, R11 \ + UMULH R12, R2, R0 \ + ADDS R0, R8, R7 \ + UMULH R13, R2, R0 \ + ADCS R0, R9, R8 \ + UMULH R14, R2, R0 \ + ADCS R0, R10, R9 \ + UMULH R15, R2, R0 \ + ADCS R0, R11, R10 \ + +#define MUL_WORD_N() \ + MUL R3, R2, R0 \ + ADDS R0, R7, R7 \ + MUL R4, R2, R0 \ + ADCS R0, R8, R8 \ + MUL R5, R2, R0 \ + ADCS R0, R9, R9 \ + MUL R6, R2, R0 \ + ADCS R0, R10, R10 \ + ADCS ZR, ZR, R11 \ + UMULH R3, R2, R0 \ + ADDS R0, R8, R8 \ + UMULH R4, R2, R0 \ + ADCS R0, R9, R9 \ + UMULH R5, R2, R0 \ + ADCS R0, R10, R10 \ + UMULH R6, R2, R0 \ + ADCS R0, R11, R11 \ + DIVSHIFT() \ + +#define MUL_WORD_0() \ + MUL R3, R2, R7 \ + MUL R4, R2, R8 \ + MUL R5, R2, R9 \ + MUL R6, R2, R10 \ + UMULH R3, R2, R0 \ + ADDS R0, R8, R8 \ + UMULH R4, R2, R0 \ + ADCS R0, R9, R9 \ + UMULH R5, R2, R0 \ + ADCS R0, R10, R10 \ + UMULH R6, R2, R0 \ + ADCS ZR, R0, R11 \ + DIVSHIFT() \ + + // mul body + MOVD 0(R1), R2 + MUL_WORD_0() + MOVD 8(R1), R2 + MUL_WORD_N() + MOVD 16(R1), R2 + MUL_WORD_N() + MOVD 24(R1), R2 + MUL_WORD_N() + + // reduce if necessary + SUBS R12, R7, R12 + SBCS R13, R8, R13 + SBCS R14, R9, R14 + SBCS R15, R10, R15 + CSEL CS, R12, R7, R7 + CSEL CS, R13, R8, R8 + CSEL CS, R14, R9, R9 + CSEL CS, R15, R10, R10 + MOVD res+0(FP), R0 + STP (R7, R8), 0(R0) + STP (R9, R10), 16(R0) + RET diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 6824a06edb..3734a8078a 100644 --- 
a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -112,6 +112,8 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { f.generateSub() f.generateButterfly() + f.generateMul() + return nil } diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index d75d21e2fa..f865804a9a 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -188,25 +188,165 @@ func (f *FFArm64) generateButterfly() { f.RET() } -func (f *FFArm64) reduce(z, t []arm64.Register) { +func (f *FFArm64) generateMul() { + f.Comment("mul(res, x, y *Element)") + registers := f.FnHeader("mul", 0, 24) + defer f.AssertCleanStack(0, 0) + + xPtr := registers.Pop() + yPtr := registers.Pop() + bi := registers.Pop() + a := registers.PopN(f.NbWords) + t := registers.PopN(f.NbWords + 1) + q := registers.PopN(f.NbWords) + + f.LDP("x+8(FP)", xPtr, yPtr) + + f.load(xPtr, a) + ax := xPtr + // f.load(yPtr, y) + + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(f.qAt(i), q[i], q[i+1]) + } + + divShift := f.Define("divShift", 0, func(args ...arm64.Register) { + m := bi + f.MOVD(f.qInv0(), m) + f.MUL(t[0], m, m) + + // for j=0 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + + for j := 0; j < f.NbWords; j++ { + f.MUL(q[j], m, ax) + if j == 0 { + f.ADDS(ax, t[j], t[j]) + } else { + f.ADCS(ax, t[j], t[j]) + } + } + f.ADCS("ZR", t[f.NbWords], t[f.NbWords]) + + // propagate high bits + f.UMULH(q[0], m, ax) + for j := 1; j <= f.NbWords; j++ { + if j == 1 { + f.ADDS(ax, t[j], t[j-1]) + } else { + f.ADCS(ax, t[j], t[j-1]) + } + if j != f.NbWords { + f.UMULH(q[j], m, ax) + } + } + }) + + mulWordN := f.Define("MUL_WORD_N", 0, func(args ...arm64.Register) { + // for j=0 to N-1 + // (C,t[j]) := t[j] + a[j]*b[i] + C + + // lo bits + for j := 0; j < f.NbWords; j++ { + f.MUL(a[j], bi, ax) + + if j == 0 { + f.ADDS(ax, t[j], t[j]) + } else { + f.ADCS(ax, t[j], t[j]) + } + } + + f.ADCS("ZR", "ZR", t[f.NbWords]) + + // propagate high bits + f.UMULH(a[0], bi, ax) + for j := 1; j <= f.NbWords; j++ { + if j == 1 { + f.ADDS(ax, t[j], t[j]) + } else { + f.ADCS(ax, t[j], t[j]) + } + if j != f.NbWords { + f.UMULH(a[j], bi, ax) + } + } + divShift() + }) + + mulWord0 := f.Define("MUL_WORD_0", 0, func(args ...arm64.Register) { + // for j=0 to N-1 + // (C,t[j]) := t[j] + a[j]*b[i] + C + + // lo bits + for j := 0; j < f.NbWords; j++ { + f.MUL(a[j], bi, t[j]) + } + + // propagate high bits + f.UMULH(a[0], bi, ax) + for j := 1; j <= f.NbWords; j++ { + if j == 1 { + f.ADDS(ax, t[j], t[j]) + } else { + if j == f.NbWords { + f.ADCS("ZR", ax, t[j]) + } else { + f.ADCS(ax, t[j], t[j]) + } + } + if j != f.NbWords { + f.UMULH(a[j], bi, ax) + } + } + divShift() + }) + + f.Comment("mul body") + + for i := 0; i < f.NbWords; i++ { + f.MOVD(yPtr.At(i), bi) - if len(z) != f.NbWords || len(t) != f.NbWords { + if i == 0 { + mulWord0() + } else { + mulWordN() + } + } + f.Comment("reduce if necessary") + f.SUBS(q[0], t[0], q[0]) + for i := 1; i < f.NbWords; i++ { + f.SBCS(q[i], t[i], q[i]) + } + for i := 0; i < f.NbWords; i++ { + f.CSEL("CS", q[i], t[i], t[i]) + } + + f.MOVD("res+0(FP)", xPtr) + f.store(t[:f.NbWords], xPtr) + + f.RET() +} + +func (f *FFArm64) reduce(t, q []arm64.Register) { + + if len(t) != f.NbWords || len(q) != f.NbWords { panic("need 2*nbWords registers") } f.Comment("load modulus and subtract") for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.qAt(i), t[i], t[i+1]) + f.LDP(f.qAt(i), q[i], q[i+1]) } - f.SUBS(t[0], z[0], t[0]) + f.SUBS(q[0], 
t[0], q[0]) for i := 1; i < f.NbWords; i++ { - f.SBCS(t[i], z[i], t[i]) + f.SBCS(q[i], t[i], q[i]) } f.Comment("reduce if necessary") for i := 0; i < f.NbWords; i++ { - f.CSEL("CS", t[i], z[i], z[i]) + f.CSEL("CS", q[i], t[i], t[i]) } } diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index f9617f6e1d..9b0d8d47d0 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -264,7 +264,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // asm code generation for moduli with more than 6 words can be optimized further F.ASM = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 F.ASMVector = F.ASM && F.NbWords == 4 && F.NbBits > 225 - F.ASMArm = F.ASMVector || (F.NbWords == 6) + F.ASMArm = F.ASMVector // setting Mu 2^288 / q if F.NbWords == 4 { diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 5f596cfcf4..67a7836c65 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -226,5 +226,13 @@ func (z *{{.ElementName}}) Sub(x, y *{{.ElementName}}) *{{.ElementName}} { //go:noescape func Butterfly(a, b *{{.ElementName}}) +//go:noescape +func mul(res,x,y *{{.ElementName}}) + +func (z *{{.ElementName}}) Mul(x, y *{{.ElementName}}) *{{.ElementName}} { + mul(z,x,y) + return z +} + {{end}} ` diff --git a/field/generator/internal/templates/element/ops_purego.go b/field/generator/internal/templates/element/ops_purego.go index cce771aa1b..de8fdc263e 100644 --- a/field/generator/internal/templates/element/ops_purego.go +++ b/field/generator/internal/templates/element/ops_purego.go @@ -93,6 +93,8 @@ func (vector *Vector) Mul(a, b Vector) { {{- end}} + +{{- if not .ASMArm}} // Mul z = x * y (mod q) {{- if $.NoCarry}} // @@ -112,6 +114,7 @@ func (z *{{.ElementName}}) Mul(x, y *{{.ElementName}}) *{{.ElementName}} { {{- end }} return z } +{{- end}} // Square z = x * x (mod q) {{- if $.NoCarry}} From 6fbfb4874f0556ba39e54f065d38dd97cb12c215 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 16 Oct 2024 20:11:22 +0000 Subject: [PATCH 11/74] checkpoint --- ecc/bls12-377/fp/element.go | 69 ------ ecc/bls12-377/fp/element_ops_arm64.go | 55 ++++ ecc/bls12-377/fp/element_ops_arm64.s | 6 + ecc/bls12-377/fp/element_ops_purego.go | 331 ------------------------- ecc/bls12-381/fp/element.go | 69 ------ ecc/bls12-381/fp/element_ops_arm64.go | 55 ++++ ecc/bls12-381/fp/element_ops_arm64.s | 6 + ecc/bls12-381/fp/element_ops_purego.go | 331 ------------------------- ecc/bw6-761/fr/element.go | 69 ------ ecc/bw6-761/fr/element_ops_arm64.go | 55 ++++ ecc/bw6-761/fr/element_ops_arm64.s | 6 + ecc/bw6-761/fr/element_ops_purego.go | 331 ------------------------- field/asm/element_6w_arm64.s | 314 +++++++++++++++++++++++ field/generator/config/field_config.go | 2 +- 14 files changed, 498 insertions(+), 1201 deletions(-) create mode 100644 ecc/bls12-377/fp/element_ops_arm64.go create mode 100644 ecc/bls12-377/fp/element_ops_arm64.s create mode 100644 ecc/bls12-381/fp/element_ops_arm64.go create mode 100644 ecc/bls12-381/fp/element_ops_arm64.s create mode 100644 ecc/bw6-761/fr/element_ops_arm64.go create mode 100644 ecc/bw6-761/fr/element_ops_arm64.s create mode 100644 field/asm/element_6w_arm64.s diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 393f45744d..408b94653d 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ 
-418,75 +418,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], carry = bits.Add64(x[3], y[3], carry) - z[4], carry = bits.Add64(x[4], y[4], carry) - z[5], _ = bits.Add64(x[5], y[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], carry = bits.Add64(x[3], x[3], carry) - z[4], carry = bits.Add64(x[4], x[4], carry) - z[5], _ = bits.Add64(x[5], x[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - z[4], b = bits.Sub64(x[4], y[4], b) - z[5], b = bits.Sub64(x[5], y[5], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], c = bits.Add64(z[3], q3, c) - z[4], c = bits.Add64(z[4], q4, c) - z[5], _ = bits.Add64(z[5], q5, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-377/fp/element_ops_arm64.go b/ecc/bls12-377/fp/element_ops_arm64.go new file mode 100644 index 0000000000..1e0188e13a --- /dev/null +++ b/ecc/bls12-377/fp/element_ops_arm64.go @@ -0,0 +1,55 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
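+
+// Note: each operation below pairs a //go:noescape assembly declaration with
+// a thin Go wrapper. //go:noescape promises the assembly keeps no reference
+// to the *Element pointers, so the arguments do not escape to the heap, and
+// the one-line wrappers are cheap enough for the compiler to inline at call
+// sites.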
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} + +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/bls12-377/fp/element_ops_arm64.s b/ecc/bls12-377/fp/element_ops_arm64.s new file mode 100644 index 0000000000..8fc2d40460 --- /dev/null +++ b/ecc/bls12-377/fp/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 6977593885298654654 +#include "../../../field/asm/element_6w_arm64.s" + diff --git a/ecc/bls12-377/fp/element_ops_purego.go b/ecc/bls12-377/fp/element_ops_purego.go index 5b720d9fde..dcff0bf7a3 100644 --- a/ecc/bls12-377/fp/element_ops_purego.go +++ b/ecc/bls12-377/fp/element_ops_purego.go @@ -46,15 +46,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// TODO @gbotrel fixme. -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } @@ -63,328 +54,6 @@ func reduce(z *Element) { _reduceGeneric(z) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3, t4, t5 uint64 - var u0, u1, u2, u3, u4, u5 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - u4, t4 = bits.Mul64(v, y[4]) - u5, t5 = bits.Mul64(v, y[5]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = 
bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[4] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[5] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, 
c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - z[4] = t4 - z[5] = t5 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index f0bcfe51bc..1609ca9524 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -418,75 +418,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], carry = bits.Add64(x[3], y[3], carry) - z[4], carry = bits.Add64(x[4], y[4], carry) - z[5], _ = bits.Add64(x[5], y[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], carry = bits.Add64(x[3], x[3], carry) - z[4], carry = bits.Add64(x[4], x[4], carry) - z[5], _ = bits.Add64(x[5], x[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - z[4], b = bits.Sub64(x[4], y[4], b) - z[5], b = bits.Sub64(x[5], y[5], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], c = bits.Add64(z[3], q3, c) - z[4], c = 
bits.Add64(z[4], q4, c) - z[5], _ = bits.Add64(z[5], q5, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-381/fp/element_ops_arm64.go b/ecc/bls12-381/fp/element_ops_arm64.go new file mode 100644 index 0000000000..1e0188e13a --- /dev/null +++ b/ecc/bls12-381/fp/element_ops_arm64.go @@ -0,0 +1,55 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} + +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/bls12-381/fp/element_ops_arm64.s b/ecc/bls12-381/fp/element_ops_arm64.s new file mode 100644 index 0000000000..8fc2d40460 --- /dev/null +++ b/ecc/bls12-381/fp/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 6977593885298654654 +#include "../../../field/asm/element_6w_arm64.s" + diff --git a/ecc/bls12-381/fp/element_ops_purego.go b/ecc/bls12-381/fp/element_ops_purego.go index e9d87cc4a7..e818762b17 100644 --- a/ecc/bls12-381/fp/element_ops_purego.go +++ b/ecc/bls12-381/fp/element_ops_purego.go @@ -46,15 +46,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// TODO @gbotrel fixme. -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } @@ -63,328 +54,6 @@ func reduce(z *Element) { _reduceGeneric(z) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3, t4, t5 uint64 - var u0, u1, u2, u3, u4, u5 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - u4, t4 = bits.Mul64(v, y[4]) - u5, t5 = bits.Mul64(v, y[5]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = 
bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[4] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[5] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, 
c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - z[4] = t4 - z[5] = t5 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index 6784bc911f..a887b71537 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -418,75 +418,6 @@ func (z *Element) fromMont() *Element { return z } -// Add z = x + y (mod q) -func (z *Element) Add(x, y *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], carry = bits.Add64(x[3], y[3], carry) - z[4], carry = bits.Add64(x[4], y[4], carry) - z[5], _ = bits.Add64(x[5], y[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - -// Double z = x + x (mod q), aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - - var carry uint64 - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], carry = bits.Add64(x[3], x[3], carry) - z[4], carry = bits.Add64(x[4], x[4], carry) - z[5], _ = bits.Add64(x[5], x[5], carry) - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - -// Sub z = x - y (mod q) -func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - z[4], b = bits.Sub64(x[4], y[4], b) - z[5], b = bits.Sub64(x[5], y[5], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - z[1], c = bits.Add64(z[1], q1, c) - z[2], c = bits.Add64(z[2], q2, c) - z[3], c = bits.Add64(z[3], q3, c) - z[4], c = 
bits.Add64(z[4], q4, c) - z[5], _ = bits.Add64(z[5], q5, c) - } - return z -} - // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bw6-761/fr/element_ops_arm64.go b/ecc/bw6-761/fr/element_ops_arm64.go new file mode 100644 index 0000000000..9a57c7ca47 --- /dev/null +++ b/ecc/bw6-761/fr/element_ops_arm64.go @@ -0,0 +1,55 @@ +//go:build !purego +// +build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func add(res, x, y *Element) + +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) + return z +} + +//go:noescape +func double(res, x *Element) + +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} + +//go:noescape +func sub(res, x, y *Element) + +func (z *Element) Sub(x, y *Element) *Element { + sub(z, x, y) + return z +} + +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} diff --git a/ecc/bw6-761/fr/element_ops_arm64.s b/ecc/bw6-761/fr/element_ops_arm64.s new file mode 100644 index 0000000000..8fc2d40460 --- /dev/null +++ b/ecc/bw6-761/fr/element_ops_arm64.s @@ -0,0 +1,6 @@ +// +build !purego + +// Code generated by gnark-crypto/generator. DO NOT EDIT. +// We include the hash to force the Go compiler to recompile: 6977593885298654654 +#include "../../../field/asm/element_6w_arm64.s" + diff --git a/ecc/bw6-761/fr/element_ops_purego.go b/ecc/bw6-761/fr/element_ops_purego.go index 5090c3e9cd..df3ef00ace 100644 --- a/ecc/bw6-761/fr/element_ops_purego.go +++ b/ecc/bw6-761/fr/element_ops_purego.go @@ -46,15 +46,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// TODO @gbotrel fixme. -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } @@ -63,328 +54,6 @@ func reduce(z *Element) { _reduceGeneric(z) } -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t0, t1, t2, t3, t4, t5 uint64 - var u0, u1, u2, u3, u4, u5 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, y[0]) - u1, t1 = bits.Mul64(v, y[1]) - u2, t2 = bits.Mul64(v, y[2]) - u3, t3 = bits.Mul64(v, y[3]) - u4, t4 = bits.Mul64(v, y[4]) - u5, t5 = bits.Mul64(v, y[5]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = 
bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[4] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[5] - u0, c1 = bits.Mul64(v, y[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, y[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, y[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, y[3]) - t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, y[4]) - t4, c0 = bits.Add64(c1, t4, 
c0) - u5, c1 = bits.Mul64(v, y[5]) - t5, c0 = bits.Add64(c1, t5, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) - c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - z[4] = t4 - z[5] = t5 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) - } - return z -} - // Square z = x * x (mod q) // // x must be less than q diff --git a/field/asm/element_6w_arm64.s b/field/asm/element_6w_arm64.s new file mode 100644 index 0000000000..7705fc52d7 --- /dev/null +++ b/field/asm/element_6w_arm64.s @@ -0,0 +1,314 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. +#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +// add(res, x, y *Element) +TEXT ·add(SB), NOSPLIT, $0-24 + LDP x+8(FP), (R19, R20) + LDP 0(R19), (R12, R13) + LDP 16(R19), (R14, R15) + LDP 32(R19), (R16, R17) + LDP 0(R20), (R6, R7) + LDP 16(R20), (R8, R9) + LDP 32(R20), (R10, R11) + ADDS R12, R6, R6 + ADCS R13, R7, R7 + ADCS R14, R8, R8 + ADCS R15, R9, R9 + ADCS R16, R10, R10 + ADCS R17, R11, R11 + + // load modulus and subtract + LDP ·qElement+0(SB), (R0, R1) + LDP ·qElement+16(SB), (R2, R3) + LDP ·qElement+32(SB), (R4, R5) + SUBS R0, R6, R0 + SBCS R1, R7, R1 + SBCS R2, R8, R2 + SBCS R3, R9, R3 + SBCS R4, R10, R4 + SBCS R5, R11, R5 + + // reduce if necessary + CSEL CS, R0, R6, R6 + CSEL CS, R1, R7, R7 + CSEL CS, R2, R8, R8 + CSEL CS, R3, R9, R9 + CSEL CS, R4, R10, R10 + CSEL CS, R5, R11, R11 + + // store + MOVD res+0(FP), R21 + STP (R6, R7), 0(R21) + STP (R8, R9), 16(R21) + STP (R10, R11), 32(R21) + RET + +// double(res, x *Element) +TEXT ·double(SB), NOSPLIT, $0-16 + LDP res+0(FP), (R1, R0) + LDP 0(R0), (R2, R3) + LDP 16(R0), (R4, R5) + LDP 32(R0), (R6, R7) + ADDS R2, R2, R2 + ADCS R3, R3, R3 + ADCS R4, R4, R4 + ADCS R5, R5, R5 + ADCS R6, R6, R6 + ADCS R7, R7, R7 + + // load modulus and subtract + LDP ·qElement+0(SB), (R8, R9) + LDP ·qElement+16(SB), (R10, R11) + LDP ·qElement+32(SB), (R12, R13) + SUBS R8, R2, R8 + SBCS R9, R3, R9 + SBCS R10, R4, R10 + SBCS R11, R5, R11 + SBCS R12, R6, R12 + SBCS R13, R7, R13 + + // reduce if necessary + CSEL CS, R8, R2, R2 + CSEL CS, R9, R3, R3 + CSEL CS, R10, R4, R4 + CSEL CS, R11, R5, R5 + CSEL CS, R12, R6, R6 + CSEL CS, R13, R7, R7 + STP (R2, R3), 0(R1) + STP (R4, R5), 16(R1) + STP (R6, R7), 32(R1) + RET + +// sub(res, x, y *Element) +TEXT ·sub(SB), NOSPLIT, $0-24 + LDP x+8(FP), (R19, R20) + LDP 0(R19), (R6, R7) + LDP 16(R19), (R8, R9) + LDP 32(R19), (R10, R11) + LDP 
0(R20), (R0, R1) + LDP 16(R20), (R2, R3) + LDP 32(R20), (R4, R5) + SUBS R0, R6, R0 + SBCS R1, R7, R1 + SBCS R2, R8, R2 + SBCS R3, R9, R3 + SBCS R4, R10, R4 + SBCS R5, R11, R5 + + // load modulus and select + LDP ·qElement+0(SB), (R12, R13) + LDP ·qElement+16(SB), (R14, R15) + LDP ·qElement+32(SB), (R16, R17) + CSEL CS, ZR, R12, R12 + CSEL CS, ZR, R13, R13 + CSEL CS, ZR, R14, R14 + CSEL CS, ZR, R15, R15 + CSEL CS, ZR, R16, R16 + CSEL CS, ZR, R17, R17 + + // add q if underflow, 0 if not + ADDS R0, R12, R0 + ADCS R1, R13, R1 + ADCS R2, R14, R2 + ADCS R3, R15, R3 + ADCS R4, R16, R4 + ADCS R5, R17, R5 + MOVD res+0(FP), R21 + STP (R0, R1), 0(R21) + STP (R2, R3), 16(R21) + STP (R4, R5), 32(R21) + RET + +// butterfly(x, y *Element) +TEXT ·Butterfly(SB), NOSPLIT, $0-16 + LDP x+0(FP), (R25, R26) + LDP 0(R25), (R0, R1) + LDP 16(R25), (R2, R3) + LDP 32(R25), (R4, R5) + LDP 0(R26), (R6, R7) + LDP 16(R26), (R8, R9) + LDP 32(R26), (R10, R11) + ADDS R0, R6, R12 + ADCS R1, R7, R13 + ADCS R2, R8, R14 + ADCS R3, R9, R15 + ADCS R4, R10, R16 + ADCS R5, R11, R17 + + // load modulus and subtract + LDP ·qElement+0(SB), (R19, R20) + LDP ·qElement+16(SB), (R21, R22) + LDP ·qElement+32(SB), (R23, R24) + SUBS R19, R12, R19 + SBCS R20, R13, R20 + SBCS R21, R14, R21 + SBCS R22, R15, R22 + SBCS R23, R16, R23 + SBCS R24, R17, R24 + + // reduce if necessary + CSEL CS, R19, R12, R12 + CSEL CS, R20, R13, R13 + CSEL CS, R21, R14, R14 + CSEL CS, R22, R15, R15 + CSEL CS, R23, R16, R16 + CSEL CS, R24, R17, R17 + + // store + STP (R12, R13), 0(R25) + STP (R14, R15), 16(R25) + STP (R16, R17), 32(R25) + SUBS R6, R0, R6 + SBCS R7, R1, R7 + SBCS R8, R2, R8 + SBCS R9, R3, R9 + SBCS R10, R4, R10 + SBCS R11, R5, R11 + + // load modulus and select + LDP ·qElement+0(SB), (R19, R20) + LDP ·qElement+16(SB), (R21, R22) + LDP ·qElement+32(SB), (R23, R24) + CSEL CS, ZR, R19, R19 + CSEL CS, ZR, R20, R20 + CSEL CS, ZR, R21, R21 + CSEL CS, ZR, R22, R22 + CSEL CS, ZR, R23, R23 + CSEL CS, ZR, R24, R24 + + // add q if underflow, 0 if not + ADDS R6, R19, R6 + ADCS R7, R20, R7 + ADCS R8, R21, R8 + ADCS R9, R22, R9 + ADCS R10, R23, R10 + ADCS R11, R24, R11 + + // store + STP (R6, R7), 0(R26) + STP (R8, R9), 16(R26) + STP (R10, R11), 32(R26) + RET + +// mul(res, x, y *Element) +TEXT ·mul(SB), NOSPLIT, $0-24 + LDP x+8(FP), (R0, R1) + LDP 0(R0), (R3, R4) + LDP 16(R0), (R5, R6) + LDP 32(R0), (R7, R8) + LDP ·qElement+0(SB), (R16, R17) + LDP ·qElement+16(SB), (R19, R20) + LDP ·qElement+32(SB), (R21, R22) + +#define DIVSHIFT() \ + MOVD $const_qInvNeg, R2 \ + MUL R9, R2, R2 \ + MUL R16, R2, R0 \ + ADDS R0, R9, R9 \ + MUL R17, R2, R0 \ + ADCS R0, R10, R10 \ + MUL R19, R2, R0 \ + ADCS R0, R11, R11 \ + MUL R20, R2, R0 \ + ADCS R0, R12, R12 \ + MUL R21, R2, R0 \ + ADCS R0, R13, R13 \ + MUL R22, R2, R0 \ + ADCS R0, R14, R14 \ + ADCS ZR, R15, R15 \ + UMULH R16, R2, R0 \ + ADDS R0, R10, R9 \ + UMULH R17, R2, R0 \ + ADCS R0, R11, R10 \ + UMULH R19, R2, R0 \ + ADCS R0, R12, R11 \ + UMULH R20, R2, R0 \ + ADCS R0, R13, R12 \ + UMULH R21, R2, R0 \ + ADCS R0, R14, R13 \ + UMULH R22, R2, R0 \ + ADCS R0, R15, R14 \ + +#define MUL_WORD_N() \ + MUL R3, R2, R0 \ + ADDS R0, R9, R9 \ + MUL R4, R2, R0 \ + ADCS R0, R10, R10 \ + MUL R5, R2, R0 \ + ADCS R0, R11, R11 \ + MUL R6, R2, R0 \ + ADCS R0, R12, R12 \ + MUL R7, R2, R0 \ + ADCS R0, R13, R13 \ + MUL R8, R2, R0 \ + ADCS R0, R14, R14 \ + ADCS ZR, ZR, R15 \ + UMULH R3, R2, R0 \ + ADDS R0, R10, R10 \ + UMULH R4, R2, R0 \ + ADCS R0, R11, R11 \ + UMULH R5, R2, R0 \ + ADCS R0, R12, R12 \ + UMULH R6, R2, R0 \ + ADCS R0, R13, R13 \ + 
UMULH R7, R2, R0 \ + ADCS R0, R14, R14 \ + UMULH R8, R2, R0 \ + ADCS R0, R15, R15 \ + DIVSHIFT() \ + +#define MUL_WORD_0() \ + MUL R3, R2, R9 \ + MUL R4, R2, R10 \ + MUL R5, R2, R11 \ + MUL R6, R2, R12 \ + MUL R7, R2, R13 \ + MUL R8, R2, R14 \ + UMULH R3, R2, R0 \ + ADDS R0, R10, R10 \ + UMULH R4, R2, R0 \ + ADCS R0, R11, R11 \ + UMULH R5, R2, R0 \ + ADCS R0, R12, R12 \ + UMULH R6, R2, R0 \ + ADCS R0, R13, R13 \ + UMULH R7, R2, R0 \ + ADCS R0, R14, R14 \ + UMULH R8, R2, R0 \ + ADCS ZR, R0, R15 \ + DIVSHIFT() \ + + // mul body + MOVD 0(R1), R2 + MUL_WORD_0() + MOVD 8(R1), R2 + MUL_WORD_N() + MOVD 16(R1), R2 + MUL_WORD_N() + MOVD 24(R1), R2 + MUL_WORD_N() + MOVD 32(R1), R2 + MUL_WORD_N() + MOVD 40(R1), R2 + MUL_WORD_N() + + // reduce if necessary + SUBS R16, R9, R16 + SBCS R17, R10, R17 + SBCS R19, R11, R19 + SBCS R20, R12, R20 + SBCS R21, R13, R21 + SBCS R22, R14, R22 + CSEL CS, R16, R9, R9 + CSEL CS, R17, R10, R10 + CSEL CS, R19, R11, R11 + CSEL CS, R20, R12, R12 + CSEL CS, R21, R13, R13 + CSEL CS, R22, R14, R14 + MOVD res+0(FP), R0 + STP (R9, R10), 0(R0) + STP (R11, R12), 16(R0) + STP (R13, R14), 32(R0) + RET diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 9b0d8d47d0..f9617f6e1d 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -264,7 +264,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // asm code generation for moduli with more than 6 words can be optimized further F.ASM = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 F.ASMVector = F.ASM && F.NbWords == 4 && F.NbBits > 225 - F.ASMArm = F.ASMVector + F.ASMArm = F.ASMVector || (F.NbWords == 6) // setting Mu 2^288 / q if F.NbWords == 4 { From 6f4d15b6f6b31f4ce015d0bfa5d340b108d93dc2 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 18 Oct 2024 02:36:49 +0000 Subject: [PATCH 12/74] checkpoint --- ecc/bls12-377/fp/element_ops_arm64.s | 2 +- ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fp/element_ops_arm64.s | 2 +- ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/bw6-761/fr/element_ops_arm64.s | 2 +- ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element_ops_arm64.s | 2 +- field/asm/element_4w_arm64.s | 61 ++++++++--------- field/asm/element_6w_arm64.s | 83 ++++++++++++------------ field/generator/asm/arm64/element_ops.go | 53 ++++++++++----- 14 files changed, 120 insertions(+), 99 deletions(-) diff --git a/ecc/bls12-377/fp/element_ops_arm64.s b/ecc/bls12-377/fp/element_ops_arm64.s index 8fc2d40460..32bd0163be 100644 --- a/ecc/bls12-377/fp/element_ops_arm64.s +++ b/ecc/bls12-377/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 6977593885298654654 +// We include the hash to force the Go compiler to recompile: 8872515016958290553 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index 5843be3b0a..4a58a8a5a5 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
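The field_config.go change above is the switch that turns the new backend on: F.ASMArm previously tracked F.ASMVector (4-limb fields only) and now also covers every 6-limb modulus, which is why the 6-word fields in this series (bls12-377/fp, bls12-381/fp, bw6-761/fr) include element_6w_arm64.s while the 4-word fields include element_4w_arm64.s. A reduced Go sketch of the gating (a sketch only; asmVector and nbWords stand in for the FieldConfig fields used in the diff):

    // supportsARM64 mirrors the updated gating in NewFieldConfig: the
    // 4-limb vectorized path, plus a scalar arm64 path for all 6-limb
    // moduli.
    func supportsARM64(asmVector bool, nbWords int) bool {
        return asmVector || nbWords == 6
    }
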
-// We include the hash to force the Go compiler to recompile: 9556532307245233012
+// We include the hash to force the Go compiler to recompile: 5513943392596292977
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/bls12-381/fp/element_ops_arm64.s b/ecc/bls12-381/fp/element_ops_arm64.s
index 8fc2d40460..32bd0163be 100644
--- a/ecc/bls12-381/fp/element_ops_arm64.s
+++ b/ecc/bls12-381/fp/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 6977593885298654654
+// We include the hash to force the Go compiler to recompile: 8872515016958290553
 #include "../../../field/asm/element_6w_arm64.s"
 
diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s
index 5843be3b0a..4a58a8a5a5 100644
--- a/ecc/bls12-381/fr/element_ops_arm64.s
+++ b/ecc/bls12-381/fr/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 9556532307245233012
+// We include the hash to force the Go compiler to recompile: 5513943392596292977
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s
index 5843be3b0a..4a58a8a5a5 100644
--- a/ecc/bls24-315/fr/element_ops_arm64.s
+++ b/ecc/bls24-315/fr/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 9556532307245233012
+// We include the hash to force the Go compiler to recompile: 5513943392596292977
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s
index 5843be3b0a..4a58a8a5a5 100644
--- a/ecc/bls24-317/fr/element_ops_arm64.s
+++ b/ecc/bls24-317/fr/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 9556532307245233012
+// We include the hash to force the Go compiler to recompile: 5513943392596292977
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s
index 5843be3b0a..4a58a8a5a5 100644
--- a/ecc/bn254/fp/element_ops_arm64.s
+++ b/ecc/bn254/fp/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
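Each of the .s stubs above carries only two lines of substance: the #include of the shared assembly and a hash comment. The Go build cache fingerprints the stub's own bytes, not the files it includes, so whenever field/asm/element_*_arm64.s is regenerated the generator rewrites the hash comment to force every package that includes it to recompile. A hedged sketch of how such a comment can be derived (the hash function and helper name are assumptions for illustration, not the actual gnark-crypto code):

    import (
        "fmt"
        "hash/fnv"
        "os"
    )

    // stubHashComment returns the "We include the hash..." line for a
    // shared asm file; the stub's content changes whenever the shared
    // file does, invalidating the build cache entry.
    func stubHashComment(sharedAsmPath string) (string, error) {
        data, err := os.ReadFile(sharedAsmPath)
        if err != nil {
            return "", err
        }
        h := fnv.New64a() // assumption: any stable 64-bit hash would do
        h.Write(data)
        return fmt.Sprintf("// We include the hash to force the Go compiler to recompile: %d", h.Sum64()), nil
    }
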
-// We include the hash to force the Go compiler to recompile: 9556532307245233012 +// We include the hash to force the Go compiler to recompile: 5513943392596292977 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bw6-761/fr/element_ops_arm64.s b/ecc/bw6-761/fr/element_ops_arm64.s index 8fc2d40460..32bd0163be 100644 --- a/ecc/bw6-761/fr/element_ops_arm64.s +++ b/ecc/bw6-761/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 6977593885298654654 +// We include the hash to force the Go compiler to recompile: 8872515016958290553 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index 5843be3b0a..4a58a8a5a5 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9556532307245233012 +// We include the hash to force the Go compiler to recompile: 5513943392596292977 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index 5843be3b0a..4a58a8a5a5 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9556532307245233012 +// We include the hash to force the Go compiler to recompile: 5513943392596292977 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index be6a78fc3f..5ef1019bd6 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -147,43 +147,36 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 // mul(res, x, y *Element) TEXT ·mul(SB), NOSPLIT, $0-24 - LDP x+8(FP), (R0, R1) - LDP 0(R0), (R3, R4) - LDP 16(R0), (R5, R6) - LDP ·qElement+0(SB), (R12, R13) - LDP ·qElement+16(SB), (R14, R15) - #define DIVSHIFT() \ - MOVD $const_qInvNeg, R2 \ - MUL R7, R2, R2 \ - MUL R12, R2, R0 \ - ADDS R0, R7, R7 \ - MUL R13, R2, R0 \ - ADCS R0, R8, R8 \ - MUL R14, R2, R0 \ - ADCS R0, R9, R9 \ - MUL R15, R2, R0 \ - ADCS R0, R10, R10 \ - ADCS ZR, R11, R11 \ - UMULH R12, R2, R0 \ - ADDS R0, R8, R7 \ - UMULH R13, R2, R0 \ - ADCS R0, R9, R8 \ - UMULH R14, R2, R0 \ - ADCS R0, R10, R9 \ - UMULH R15, R2, R0 \ - ADCS R0, R11, R10 \ + MUL R12, R17, R0 \ + ADDS R0, R7, R7 \ + MUL R13, R17, R0 \ + ADCS R0, R8, R8 \ + MUL R14, R17, R0 \ + ADCS R0, R9, R9 \ + MUL R15, R17, R0 \ + ADCS R0, R10, R10 \ + ADC ZR, R11, R11 \ + UMULH R12, R17, R0 \ + ADDS R0, R8, R7 \ + UMULH R13, R17, R0 \ + ADCS R0, R9, R8 \ + UMULH R14, R17, R0 \ + ADCS R0, R10, R9 \ + UMULH R15, R17, R0 \ + ADCS R0, R11, R10 \ #define MUL_WORD_N() \ MUL R3, R2, R0 \ ADDS R0, R7, R7 \ + MUL R7, R16, R17 \ MUL R4, R2, R0 \ ADCS R0, R8, R8 \ MUL R5, R2, R0 \ ADCS R0, R9, R9 \ MUL R6, R2, R0 \ ADCS R0, R10, R10 \ - ADCS ZR, ZR, R11 \ + ADC ZR, ZR, R11 \ UMULH R3, R2, R0 \ ADDS R0, R8, R8 \ UMULH R4, R2, R0 \ @@ -191,7 +184,7 @@ TEXT ·mul(SB), NOSPLIT, $0-24 UMULH R5, R2, R0 \ ADCS R0, R10, R10 \ UMULH R6, R2, R0 \ - ADCS R0, R11, R11 \ + ADC R0, R11, R11 \ DIVSHIFT() \ #define MUL_WORD_0() \ @@ -206,11 +199,19 @@ TEXT ·mul(SB), NOSPLIT, $0-24 UMULH R5, R2, R0 \ ADCS R0, R10, R10 
\ UMULH R6, R2, R0 \ - ADCS ZR, R0, R11 \ + ADC R0, ZR, R11 \ + MUL R7, R16, R17 \ DIVSHIFT() \ // mul body + MOVD y+16(FP), R1 + MOVD x+8(FP), R0 + LDP 0(R0), (R3, R4) + LDP 16(R0), (R5, R6) MOVD 0(R1), R2 + MOVD $const_qInvNeg, R16 + LDP ·qElement+0(SB), (R12, R13) + LDP ·qElement+16(SB), (R14, R15) MUL_WORD_0() MOVD 8(R1), R2 MUL_WORD_N() @@ -224,11 +225,11 @@ TEXT ·mul(SB), NOSPLIT, $0-24 SBCS R13, R8, R13 SBCS R14, R9, R14 SBCS R15, R10, R15 + MOVD res+0(FP), R0 CSEL CS, R12, R7, R7 CSEL CS, R13, R8, R8 + STP (R7, R8), 0(R0) CSEL CS, R14, R9, R9 CSEL CS, R15, R10, R10 - MOVD res+0(FP), R0 - STP (R7, R8), 0(R0) STP (R9, R10), 16(R0) RET diff --git a/field/asm/element_6w_arm64.s b/field/asm/element_6w_arm64.s index 7705fc52d7..89838f6130 100644 --- a/field/asm/element_6w_arm64.s +++ b/field/asm/element_6w_arm64.s @@ -194,46 +194,37 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 // mul(res, x, y *Element) TEXT ·mul(SB), NOSPLIT, $0-24 - LDP x+8(FP), (R0, R1) - LDP 0(R0), (R3, R4) - LDP 16(R0), (R5, R6) - LDP 32(R0), (R7, R8) - LDP ·qElement+0(SB), (R16, R17) - LDP ·qElement+16(SB), (R19, R20) - LDP ·qElement+32(SB), (R21, R22) - #define DIVSHIFT() \ - MOVD $const_qInvNeg, R2 \ - MUL R9, R2, R2 \ - MUL R16, R2, R0 \ - ADDS R0, R9, R9 \ - MUL R17, R2, R0 \ - ADCS R0, R10, R10 \ - MUL R19, R2, R0 \ - ADCS R0, R11, R11 \ - MUL R20, R2, R0 \ - ADCS R0, R12, R12 \ - MUL R21, R2, R0 \ - ADCS R0, R13, R13 \ - MUL R22, R2, R0 \ - ADCS R0, R14, R14 \ - ADCS ZR, R15, R15 \ - UMULH R16, R2, R0 \ - ADDS R0, R10, R9 \ - UMULH R17, R2, R0 \ - ADCS R0, R11, R10 \ - UMULH R19, R2, R0 \ - ADCS R0, R12, R11 \ - UMULH R20, R2, R0 \ - ADCS R0, R13, R12 \ - UMULH R21, R2, R0 \ - ADCS R0, R14, R13 \ - UMULH R22, R2, R0 \ - ADCS R0, R15, R14 \ + MUL R16, R24, R0 \ + ADDS R0, R9, R9 \ + MUL R17, R24, R0 \ + ADCS R0, R10, R10 \ + MUL R19, R24, R0 \ + ADCS R0, R11, R11 \ + MUL R20, R24, R0 \ + ADCS R0, R12, R12 \ + MUL R21, R24, R0 \ + ADCS R0, R13, R13 \ + MUL R22, R24, R0 \ + ADCS R0, R14, R14 \ + ADC ZR, R15, R15 \ + UMULH R16, R24, R0 \ + ADDS R0, R10, R9 \ + UMULH R17, R24, R0 \ + ADCS R0, R11, R10 \ + UMULH R19, R24, R0 \ + ADCS R0, R12, R11 \ + UMULH R20, R24, R0 \ + ADCS R0, R13, R12 \ + UMULH R21, R24, R0 \ + ADCS R0, R14, R13 \ + UMULH R22, R24, R0 \ + ADCS R0, R15, R14 \ #define MUL_WORD_N() \ MUL R3, R2, R0 \ ADDS R0, R9, R9 \ + MUL R9, R23, R24 \ MUL R4, R2, R0 \ ADCS R0, R10, R10 \ MUL R5, R2, R0 \ @@ -244,7 +235,7 @@ TEXT ·mul(SB), NOSPLIT, $0-24 ADCS R0, R13, R13 \ MUL R8, R2, R0 \ ADCS R0, R14, R14 \ - ADCS ZR, ZR, R15 \ + ADC ZR, ZR, R15 \ UMULH R3, R2, R0 \ ADDS R0, R10, R10 \ UMULH R4, R2, R0 \ @@ -256,7 +247,7 @@ TEXT ·mul(SB), NOSPLIT, $0-24 UMULH R7, R2, R0 \ ADCS R0, R14, R14 \ UMULH R8, R2, R0 \ - ADCS R0, R15, R15 \ + ADC R0, R15, R15 \ DIVSHIFT() \ #define MUL_WORD_0() \ @@ -277,11 +268,21 @@ TEXT ·mul(SB), NOSPLIT, $0-24 UMULH R7, R2, R0 \ ADCS R0, R14, R14 \ UMULH R8, R2, R0 \ - ADCS ZR, R0, R15 \ + ADC R0, ZR, R15 \ + MUL R9, R23, R24 \ DIVSHIFT() \ // mul body + MOVD y+16(FP), R1 + MOVD x+8(FP), R0 + LDP 0(R0), (R3, R4) + LDP 16(R0), (R5, R6) + LDP 32(R0), (R7, R8) MOVD 0(R1), R2 + MOVD $const_qInvNeg, R23 + LDP ·qElement+0(SB), (R16, R17) + LDP ·qElement+16(SB), (R19, R20) + LDP ·qElement+32(SB), (R21, R22) MUL_WORD_0() MOVD 8(R1), R2 MUL_WORD_N() @@ -301,14 +302,14 @@ TEXT ·mul(SB), NOSPLIT, $0-24 SBCS R20, R12, R20 SBCS R21, R13, R21 SBCS R22, R14, R22 + MOVD res+0(FP), R0 CSEL CS, R16, R9, R9 CSEL CS, R17, R10, R10 + STP (R9, R10), 0(R0) CSEL CS, R19, R11, R11 CSEL CS, R20, R12, R12 + 
STP (R11, R12), 16(R0) CSEL CS, R21, R13, R13 CSEL CS, R22, R14, R14 - MOVD res+0(FP), R0 - STP (R9, R10), 0(R0) - STP (R11, R12), 16(R0) STP (R13, R14), 32(R0) RET diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index f865804a9a..29d7279138 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -200,20 +200,13 @@ func (f *FFArm64) generateMul() { t := registers.PopN(f.NbWords + 1) q := registers.PopN(f.NbWords) - f.LDP("x+8(FP)", xPtr, yPtr) - - f.load(xPtr, a) ax := xPtr - // f.load(yPtr, y) - - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.qAt(i), q[i], q[i+1]) - } + qInv0 := registers.Pop() + m := registers.Pop() divShift := f.Define("divShift", 0, func(args ...arm64.Register) { - m := bi - f.MOVD(f.qInv0(), m) - f.MUL(t[0], m, m) + // m := bi + // f.MUL(t[0], qInv0, m) // for j=0 to N-1 // (C,t[j-1]) := t[j] + m*q[j] + C @@ -226,7 +219,7 @@ func (f *FFArm64) generateMul() { f.ADCS(ax, t[j], t[j]) } } - f.ADCS("ZR", t[f.NbWords], t[f.NbWords]) + f.ADC("ZR", t[f.NbWords], t[f.NbWords]) // propagate high bits f.UMULH(q[0], m, ax) @@ -252,20 +245,27 @@ func (f *FFArm64) generateMul() { if j == 0 { f.ADDS(ax, t[j], t[j]) + f.MUL(t[0], qInv0, m) } else { f.ADCS(ax, t[j], t[j]) } } - f.ADCS("ZR", "ZR", t[f.NbWords]) + f.ADC("ZR", "ZR", t[f.NbWords]) // propagate high bits f.UMULH(a[0], bi, ax) for j := 1; j <= f.NbWords; j++ { if j == 1 { f.ADDS(ax, t[j], t[j]) + } else { - f.ADCS(ax, t[j], t[j]) + if j == f.NbWords { + + f.ADC(ax, t[j], t[j]) + } else { + f.ADCS(ax, t[j], t[j]) + } } if j != f.NbWords { f.UMULH(a[j], bi, ax) @@ -285,12 +285,15 @@ func (f *FFArm64) generateMul() { // propagate high bits f.UMULH(a[0], bi, ax) + for j := 1; j <= f.NbWords; j++ { if j == 1 { f.ADDS(ax, t[j], t[j]) + } else { if j == f.NbWords { - f.ADCS("ZR", ax, t[j]) + + f.ADC(ax, "ZR", t[j]) } else { f.ADCS(ax, t[j], t[j]) } @@ -299,31 +302,47 @@ func (f *FFArm64) generateMul() { f.UMULH(a[j], bi, ax) } } + f.MUL(t[0], qInv0, m) divShift() }) f.Comment("mul body") + // f.LDP("x+8(FP)", xPtr, yPtr) + f.MOVD("y+16(FP)", yPtr) + f.MOVD("x+8(FP)", xPtr) + f.load(xPtr, a) for i := 0; i < f.NbWords; i++ { f.MOVD(yPtr.At(i), bi) if i == 0 { + f.MOVD(f.qInv0(), qInv0) + + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(f.qAt(i), q[i], q[i+1]) + } + mulWord0() } else { mulWordN() } } + f.Comment("reduce if necessary") f.SUBS(q[0], t[0], q[0]) for i := 1; i < f.NbWords; i++ { f.SBCS(q[i], t[i], q[i]) } + + f.MOVD("res+0(FP)", ax) for i := 0; i < f.NbWords; i++ { f.CSEL("CS", q[i], t[i], t[i]) + if i%2 == 1 { + f.STP(t[i-1], t[i], ax.At(i-1)) + } } - f.MOVD("res+0(FP)", xPtr) - f.store(t[:f.NbWords], xPtr) + // f.store(t[:f.NbWords], resPtr) f.RET() } From 8bc826763f1c51192533175f27538ca6c4fb7e8c Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 18 Oct 2024 02:54:39 +0000 Subject: [PATCH 13/74] checkpoint --- ecc/bls12-377/fp/element_ops_arm64.s | 2 +- ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fp/element_ops_arm64.s | 2 +- ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/bw6-761/fr/element_ops_arm64.s | 2 +- ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element_ops_arm64.s | 2 +- field/asm/element_4w_arm64.s | 52 ++++++++--------- field/asm/element_6w_arm64.s | 71 +++++++++++------------- 
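Both the portable Mul bodies removed earlier and the MUL_WORD_0/MUL_WORD_N/DIVSHIFT macros rewritten above unroll the same CIOS Montgomery loop. The rewrite makes two scheduling changes: the m = t[0] * qInvNeg multiply (the new MUL R7, R16, R17 in the 4-word file, MUL R9, R23, R24 in the 6-word one) is issued in the middle of MUL_WORD_N so it is already in flight when DIVSHIFT consumes it, and the last addition of each carry chain becomes ADC instead of ADCS because nothing reads its carry-out. For reference, a limb-generic Go sketch of the loop being unrolled (this is the classic CIOS formulation; the generated code follows the reordered Algorithm 2 variant cited in the removed comments):

    import "math/bits"

    // madd returns the 128-bit value t + a*b + c as (hi, lo); hi cannot
    // overflow while absorbing the two carries since a*b <= (2^64-1)^2.
    func madd(t, a, b, c uint64) (hi, lo uint64) {
        hi, lo = bits.Mul64(a, b)
        var carry uint64
        lo, carry = bits.Add64(lo, t, 0)
        hi += carry
        lo, carry = bits.Add64(lo, c, 0)
        hi += carry
        return hi, lo
    }

    // mulCIOS computes z = x*y*R^-1 mod q with R = 2^(64*n) and
    // qInvNeg = -q^-1 mod 2^64. Each madd corresponds to one MUL/UMULH
    // pair, which is why the macros run all MULs (low halves) in one
    // ADDS/ADCS chain and all UMULHs (high halves) in a second.
    func mulCIOS(z, x, y, q []uint64, qInvNeg uint64) {
        n := len(q)
        t := make([]uint64, n+2)
        for i := 0; i < n; i++ {
            // t += x * y[i]
            var c uint64
            for j := 0; j < n; j++ {
                c, t[j] = madd(t[j], x[j], y[i], c)
            }
            t[n], c = bits.Add64(t[n], c, 0)
            t[n+1] = c

            // m makes the low limb vanish: t[0] + m*q[0] = 0 (mod 2^64)
            m := t[0] * qInvNeg
            c, _ = madd(t[0], m, q[0], 0)
            // t = (t + m*q) >> 64
            for j := 1; j < n; j++ {
                c, t[j-1] = madd(t[j], m, q[j], c)
            }
            t[n-1], c = bits.Add64(t[n], c, 0)
            t[n] = t[n+1] + c
        }
        // the unrolled versions finish with the usual correction:
        // if !z.smallerThanModulus() { z -= q }
        copy(z, t[:n])
    }
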
 field/generator/asm/arm64/element_ops.go |  62 +++++++++++++++++----
 14 files changed, 116 insertions(+), 91 deletions(-)

diff --git a/ecc/bls12-377/fp/element_ops_arm64.s b/ecc/bls12-377/fp/element_ops_arm64.s
index 32bd0163be..f4a10bcc1e 100644
--- a/ecc/bls12-377/fp/element_ops_arm64.s
+++ b/ecc/bls12-377/fp/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 8872515016958290553
+// We include the hash to force the Go compiler to recompile: 12074667263144223547
 #include "../../../field/asm/element_6w_arm64.s"
 
diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s
index 4a58a8a5a5..4327e8acf9 100644
--- a/ecc/bls12-377/fr/element_ops_arm64.s
+++ b/ecc/bls12-377/fr/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 5513943392596292977
+// We include the hash to force the Go compiler to recompile: 15593824621840630566
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/bls12-381/fp/element_ops_arm64.s b/ecc/bls12-381/fp/element_ops_arm64.s
index 32bd0163be..f4a10bcc1e 100644
--- a/ecc/bls12-381/fp/element_ops_arm64.s
+++ b/ecc/bls12-381/fp/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 8872515016958290553
+// We include the hash to force the Go compiler to recompile: 12074667263144223547
 #include "../../../field/asm/element_6w_arm64.s"
 
diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s
index 4a58a8a5a5..4327e8acf9 100644
--- a/ecc/bls12-381/fr/element_ops_arm64.s
+++ b/ecc/bls12-381/fr/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 5513943392596292977
+// We include the hash to force the Go compiler to recompile: 15593824621840630566
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s
index 4a58a8a5a5..4327e8acf9 100644
--- a/ecc/bls24-315/fr/element_ops_arm64.s
+++ b/ecc/bls24-315/fr/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 5513943392596292977
+// We include the hash to force the Go compiler to recompile: 15593824621840630566
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s
index 4a58a8a5a5..4327e8acf9 100644
--- a/ecc/bls24-317/fr/element_ops_arm64.s
+++ b/ecc/bls24-317/fr/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 5513943392596292977
+// We include the hash to force the Go compiler to recompile: 15593824621840630566
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s
index 4a58a8a5a5..4327e8acf9 100644
--- a/ecc/bn254/fp/element_ops_arm64.s
+++ b/ecc/bn254/fp/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 5513943392596292977
+// We include the hash to force the Go compiler to recompile: 15593824621840630566
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s
index 4a58a8a5a5..4327e8acf9 100644
--- a/ecc/bn254/fr/element_ops_arm64.s
+++ b/ecc/bn254/fr/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 5513943392596292977
+// We include the hash to force the Go compiler to recompile: 15593824621840630566
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/bw6-761/fr/element_ops_arm64.s b/ecc/bw6-761/fr/element_ops_arm64.s
index 32bd0163be..f4a10bcc1e 100644
--- a/ecc/bw6-761/fr/element_ops_arm64.s
+++ b/ecc/bw6-761/fr/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 8872515016958290553
+// We include the hash to force the Go compiler to recompile: 12074667263144223547
 #include "../../../field/asm/element_6w_arm64.s"
 
diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s
index 4a58a8a5a5..4327e8acf9 100644
--- a/ecc/stark-curve/fp/element_ops_arm64.s
+++ b/ecc/stark-curve/fp/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 5513943392596292977
+// We include the hash to force the Go compiler to recompile: 15593824621840630566
 #include "../../../field/asm/element_4w_arm64.s"
 
diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s
index 4a58a8a5a5..4327e8acf9 100644
--- a/ecc/stark-curve/fr/element_ops_arm64.s
+++ b/ecc/stark-curve/fr/element_ops_arm64.s
@@ -1,6 +1,6 @@
 // +build !purego
 
 // Code generated by gnark-crypto/generator. DO NOT EDIT.
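Each of these one-line .s stubs pairs with an element_ops_arm64.go wrapper added earlier in the series: the wrapper declares the assembly symbols with //go:noescape, telling the compiler the Element pointers do not escape so callers' values can stay on the stack, and the exported methods simply delegate. The pattern, reduced to its essentials (a sketch of the generated shape, not a full file):

    //go:build !purego

    package fr

    // add is implemented in the included element_*_arm64.s file.
    //
    //go:noescape
    func add(res, x, y *Element)

    // Add z = x + y (mod q)
    func (z *Element) Add(x, y *Element) *Element {
        add(z, x, y)
        return z
    }
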
-// We include the hash to force the Go compiler to recompile: 5513943392596292977 +// We include the hash to force the Go compiler to recompile: 15593824621840630566 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index 5ef1019bd6..aea8a07677 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -102,47 +102,41 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 ADDS R0, R4, R8 ADCS R1, R5, R9 ADCS R2, R6, R10 - ADCS R3, R7, R11 - - // load modulus and subtract - LDP ·qElement+0(SB), (R12, R13) - LDP ·qElement+16(SB), (R14, R15) - SUBS R12, R8, R12 - SBCS R13, R9, R13 - SBCS R14, R10, R14 - SBCS R15, R11, R15 - - // reduce if necessary - CSEL CS, R12, R8, R8 - CSEL CS, R13, R9, R9 - CSEL CS, R14, R10, R10 - CSEL CS, R15, R11, R11 - - // store - STP (R8, R9), 0(R16) - STP (R10, R11), 16(R16) + ADC R3, R7, R11 SUBS R4, R0, R4 SBCS R5, R1, R5 SBCS R6, R2, R6 SBCS R7, R3, R7 // load modulus and select - LDP ·qElement+0(SB), (R12, R13) - LDP ·qElement+16(SB), (R14, R15) - CSEL CS, ZR, R12, R12 - CSEL CS, ZR, R13, R13 - CSEL CS, ZR, R14, R14 - CSEL CS, ZR, R15, R15 + LDP ·qElement+0(SB), (R0, R1) + CSEL CS, ZR, R0, R12 + CSEL CS, ZR, R1, R13 + LDP ·qElement+16(SB), (R2, R3) + CSEL CS, ZR, R2, R14 + CSEL CS, ZR, R3, R15 // add q if underflow, 0 if not ADDS R4, R12, R4 ADCS R5, R13, R5 + STP (R4, R5), 0(R17) ADCS R6, R14, R6 - ADCS R7, R15, R7 + ADC R7, R15, R7 + STP (R6, R7), 16(R17) - // store - STP (R4, R5), 0(R17) - STP (R6, R7), 16(R17) + // load modulus and subtract + SUBS R0, R8, R0 + SBCS R1, R9, R1 + SBCS R2, R10, R2 + SBCS R3, R11, R3 + + // reduce if necessary + CSEL CS, R0, R8, R8 + CSEL CS, R1, R9, R9 + STP (R8, R9), 0(R16) + CSEL CS, R2, R10, R10 + CSEL CS, R3, R11, R11 + STP (R10, R11), 16(R16) RET // mul(res, x, y *Element) diff --git a/field/asm/element_6w_arm64.s b/field/asm/element_6w_arm64.s index 89838f6130..7693ba1451 100644 --- a/field/asm/element_6w_arm64.s +++ b/field/asm/element_6w_arm64.s @@ -135,31 +135,7 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 ADCS R2, R8, R14 ADCS R3, R9, R15 ADCS R4, R10, R16 - ADCS R5, R11, R17 - - // load modulus and subtract - LDP ·qElement+0(SB), (R19, R20) - LDP ·qElement+16(SB), (R21, R22) - LDP ·qElement+32(SB), (R23, R24) - SUBS R19, R12, R19 - SBCS R20, R13, R20 - SBCS R21, R14, R21 - SBCS R22, R15, R22 - SBCS R23, R16, R23 - SBCS R24, R17, R24 - - // reduce if necessary - CSEL CS, R19, R12, R12 - CSEL CS, R20, R13, R13 - CSEL CS, R21, R14, R14 - CSEL CS, R22, R15, R15 - CSEL CS, R23, R16, R16 - CSEL CS, R24, R17, R17 - - // store - STP (R12, R13), 0(R25) - STP (R14, R15), 16(R25) - STP (R16, R17), 32(R25) + ADC R5, R11, R17 SUBS R6, R0, R6 SBCS R7, R1, R7 SBCS R8, R2, R8 @@ -168,28 +144,45 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 SBCS R11, R5, R11 // load modulus and select - LDP ·qElement+0(SB), (R19, R20) - LDP ·qElement+16(SB), (R21, R22) - LDP ·qElement+32(SB), (R23, R24) - CSEL CS, ZR, R19, R19 - CSEL CS, ZR, R20, R20 - CSEL CS, ZR, R21, R21 - CSEL CS, ZR, R22, R22 - CSEL CS, ZR, R23, R23 - CSEL CS, ZR, R24, R24 + LDP ·qElement+0(SB), (R0, R1) + CSEL CS, ZR, R0, R19 + CSEL CS, ZR, R1, R20 + LDP ·qElement+16(SB), (R2, R3) + CSEL CS, ZR, R2, R21 + CSEL CS, ZR, R3, R22 + LDP ·qElement+32(SB), (R4, R5) + CSEL CS, ZR, R4, R23 + CSEL CS, ZR, R5, R24 // add q if underflow, 0 if not ADDS R6, R19, R6 ADCS R7, R20, R7 + STP (R6, R7), 0(R26) ADCS R8, R21, R8 ADCS R9, R22, R9 + STP (R8, R9), 16(R26) ADCS R10, R23, R10 - ADCS R11, R24, R11 + ADC R11, R24, R11 + STP 
(R10, R11), 32(R26) - // store - STP (R6, R7), 0(R26) - STP (R8, R9), 16(R26) - STP (R10, R11), 32(R26) + // load modulus and subtract + SUBS R0, R12, R0 + SBCS R1, R13, R1 + SBCS R2, R14, R2 + SBCS R3, R15, R3 + SBCS R4, R16, R4 + SBCS R5, R17, R5 + + // reduce if necessary + CSEL CS, R0, R12, R12 + CSEL CS, R1, R13, R13 + STP (R12, R13), 0(R25) + CSEL CS, R2, R14, R14 + CSEL CS, R3, R15, R15 + STP (R14, R15), 16(R25) + CSEL CS, R4, R16, R16 + CSEL CS, R5, R17, R17 + STP (R16, R17), 32(R25) RET // mul(res, x, y *Element) diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index 29d7279138..4f1ba16ea8 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -149,14 +149,19 @@ func (f *FFArm64) generateButterfly() { f.ADDS(a[0], b[0], aRes[0]) for i := 1; i < f.NbWords; i++ { - f.ADCS(a[i], b[i], aRes[i]) + + if i == f.NbWordsLastIndex { + f.ADC(a[i], b[i], aRes[i]) + } else { + f.ADCS(a[i], b[i], aRes[i]) + } } - f.reduce(aRes, t) + // f.reduce(aRes, t) - f.Comment("store") + // f.Comment("store") - f.store(aRes, aPtr) + // f.store(aRes, aPtr) bRes := b @@ -169,21 +174,29 @@ func (f *FFArm64) generateButterfly() { zero := arm64.Register("ZR") - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.qAt(i), t[i], t[i+1]) - } + // for i := 0; i < f.NbWords-1; i += 2 { + // f.LDP(f.qAt(i), t[i], t[i+1]) + // } for i := 0; i < f.NbWords; i++ { - f.CSEL("CS", zero, t[i], t[i]) + if i%2 == 0 { + f.LDP(f.qAt(i), a[i], a[i+1]) + } + f.CSEL("CS", zero, a[i], t[i]) } f.Comment("add q if underflow, 0 if not") f.ADDS(bRes[0], t[0], bRes[0]) for i := 1; i < f.NbWords; i++ { - f.ADCS(bRes[i], t[i], bRes[i]) + if i == f.NbWordsLastIndex { + f.ADC(bRes[i], t[i], bRes[i]) + } else { + f.ADCS(bRes[i], t[i], bRes[i]) + } + if i%2 == 1 { + f.STP(bRes[i-1], bRes[i], bPtr.At(i-1)) + } } - f.Comment("store") - - f.store(bRes, bPtr) + f.reduceAndStore(aRes, a, aPtr) f.RET() } @@ -380,3 +393,28 @@ func (f *FFArm64) store(z []arm64.Register, zPtr arm64.Register) { f.STP(z[i], z[i+1], zPtr.At(i)) } } + +func (f *FFArm64) reduceAndStore(t, q []arm64.Register, zPtr arm64.Register) { + + if len(t) != f.NbWords || len(q) != f.NbWords { + panic("need 2*nbWords registers") + } + + f.Comment("load modulus and subtract") + + // for i := 0; i < f.NbWords-1; i += 2 { + // f.LDP(f.qAt(i), q[i], q[i+1]) + // } + f.SUBS(q[0], t[0], q[0]) + for i := 1; i < f.NbWords; i++ { + f.SBCS(q[i], t[i], q[i]) + } + + f.Comment("reduce if necessary") + for i := 0; i < f.NbWords; i++ { + f.CSEL("CS", q[i], t[i], t[i]) + if i%2 == 1 { + f.STP(t[i-1], t[i], zPtr.At(i-1)) + } + } +} From 13521a97dec2de462f0403b3199540b68b3b5247 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sat, 19 Oct 2024 10:04:43 -0500 Subject: [PATCH 14/74] checkpoint --- ecc/bls12-377/fp/element_ops_arm64.s | 2 +- ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fp/element_ops_arm64.s | 2 +- ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/bw6-761/fr/element_ops_arm64.s | 2 +- ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element_ops_arm64.s | 2 +- field/asm/element_4w_arm64.s | 20 +++++++--------- field/asm/element_6w_arm64.s | 24 +++++++++---------- field/generator/asm/arm64/element_ops.go | 11 ++++----- field/generator/generator.go | 2 ++ .../internal/templates/element/base.go | 1 + 
.../internal/templates/element/ops_asm.go | 1 + .../internal/templates/element/tests.go | 1 + go.mod | 2 +- go.sum | 2 ++ 20 files changed, 44 insertions(+), 42 deletions(-) diff --git a/ecc/bls12-377/fp/element_ops_arm64.s b/ecc/bls12-377/fp/element_ops_arm64.s index f4a10bcc1e..c4c5981603 100644 --- a/ecc/bls12-377/fp/element_ops_arm64.s +++ b/ecc/bls12-377/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 12074667263144223547 +// We include the hash to force the Go compiler to recompile: 635053461911122795 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index 4327e8acf9..1501fe56d2 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 15593824621840630566 +// We include the hash to force the Go compiler to recompile: 11331350010912976978 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fp/element_ops_arm64.s b/ecc/bls12-381/fp/element_ops_arm64.s index f4a10bcc1e..c4c5981603 100644 --- a/ecc/bls12-381/fp/element_ops_arm64.s +++ b/ecc/bls12-381/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 12074667263144223547 +// We include the hash to force the Go compiler to recompile: 635053461911122795 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s index 4327e8acf9..1501fe56d2 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 15593824621840630566 +// We include the hash to force the Go compiler to recompile: 11331350010912976978 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s index 4327e8acf9..1501fe56d2 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 15593824621840630566 +// We include the hash to force the Go compiler to recompile: 11331350010912976978 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s index 4327e8acf9..1501fe56d2 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 15593824621840630566 +// We include the hash to force the Go compiler to recompile: 11331350010912976978 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s index 4327e8acf9..1501fe56d2 100644 --- a/ecc/bn254/fp/element_ops_arm64.s +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 15593824621840630566 +// We include the hash to force the Go compiler to recompile: 11331350010912976978 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s index 4327e8acf9..1501fe56d2 100644 --- a/ecc/bn254/fr/element_ops_arm64.s +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 15593824621840630566 +// We include the hash to force the Go compiler to recompile: 11331350010912976978 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bw6-761/fr/element_ops_arm64.s b/ecc/bw6-761/fr/element_ops_arm64.s index f4a10bcc1e..c4c5981603 100644 --- a/ecc/bw6-761/fr/element_ops_arm64.s +++ b/ecc/bw6-761/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 12074667263144223547 +// We include the hash to force the Go compiler to recompile: 635053461911122795 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index 4327e8acf9..1501fe56d2 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 15593824621840630566 +// We include the hash to force the Go compiler to recompile: 11331350010912976978 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index 4327e8acf9..1501fe56d2 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
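Aside on the hash bumps that make up most of this patch: every stub that #includes a shared asm file carries this one-line hash comment, and, as the comment itself says, rewriting it is what forces the Go compiler to recompile the stub when the shared file changes (the #include target is otherwise invisible to the build cache). A minimal sketch of producing such a fingerprint, assuming a plain content hash of the shared file; the generator's actual hashing scheme is not shown in this patch:

	package main

	import (
		"fmt"
		"hash/fnv"
		"os"
	)

	// asmFingerprint hashes the shared assembly file's bytes into a 64-bit
	// value. Illustrative only: the real generator derives some fingerprint
	// of the included file, but its exact scheme may differ.
	func asmFingerprint(path string) (uint64, error) {
		data, err := os.ReadFile(path)
		if err != nil {
			return 0, err
		}
		h := fnv.New64a()
		h.Write(data) // fnv's Write never returns an error
		return h.Sum64(), nil
	}

	func main() {
		sum, err := asmFingerprint("field/asm/element_4w_arm64.s")
		if err != nil {
			panic(err)
		}
		fmt.Printf("// We include the hash to force the Go compiler to recompile: %d\n", sum)
	}
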
-// We include the hash to force the Go compiler to recompile: 15593824621840630566 +// We include the hash to force the Go compiler to recompile: 11331350010912976978 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index aea8a07677..91a34e057b 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -4,20 +4,21 @@ #include "go_asm.h" // add(res, x, y *Element) -TEXT ·add(SB), NOSPLIT, $0-24 +TEXT ·add(SB), NOFRAME|NOSPLIT, $0-24 LDP x+8(FP), (R12, R13) LDP 0(R12), (R8, R9) LDP 16(R12), (R10, R11) LDP 0(R13), (R4, R5) LDP 16(R13), (R6, R7) + LDP ·qElement+0(SB), (R0, R1) + LDP ·qElement+16(SB), (R2, R3) ADDS R8, R4, R4 ADCS R9, R5, R5 ADCS R10, R6, R6 ADCS R11, R7, R7 + MOVD res+0(FP), R14 // load modulus and subtract - LDP ·qElement+0(SB), (R0, R1) - LDP ·qElement+16(SB), (R2, R3) SUBS R0, R4, R0 SBCS R1, R5, R1 SBCS R2, R6, R2 @@ -26,17 +27,14 @@ TEXT ·add(SB), NOSPLIT, $0-24 // reduce if necessary CSEL CS, R0, R4, R4 CSEL CS, R1, R5, R5 + STP (R4, R5), 0(R14) CSEL CS, R2, R6, R6 CSEL CS, R3, R7, R7 - - // store - MOVD res+0(FP), R14 - STP (R4, R5), 0(R14) STP (R6, R7), 16(R14) RET // double(res, x *Element) -TEXT ·double(SB), NOSPLIT, $0-16 +TEXT ·double(SB), NOFRAME|NOSPLIT, $0-16 LDP res+0(FP), (R1, R0) LDP 0(R0), (R2, R3) LDP 16(R0), (R4, R5) @@ -63,7 +61,7 @@ TEXT ·double(SB), NOSPLIT, $0-16 RET // sub(res, x, y *Element) -TEXT ·sub(SB), NOSPLIT, $0-24 +TEXT ·sub(SB), NOFRAME|NOSPLIT, $0-24 LDP x+8(FP), (R12, R13) LDP 0(R12), (R4, R5) LDP 16(R12), (R6, R7) @@ -93,7 +91,7 @@ TEXT ·sub(SB), NOSPLIT, $0-24 RET // butterfly(x, y *Element) -TEXT ·Butterfly(SB), NOSPLIT, $0-16 +TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 LDP x+0(FP), (R16, R17) LDP 0(R16), (R0, R1) LDP 16(R16), (R2, R3) @@ -140,7 +138,7 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 RET // mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 +TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 #define DIVSHIFT() \ MUL R12, R17, R0 \ ADDS R0, R7, R7 \ diff --git a/field/asm/element_6w_arm64.s b/field/asm/element_6w_arm64.s index 7693ba1451..c4a5ffd7bb 100644 --- a/field/asm/element_6w_arm64.s +++ b/field/asm/element_6w_arm64.s @@ -4,7 +4,7 @@ #include "go_asm.h" // add(res, x, y *Element) -TEXT ·add(SB), NOSPLIT, $0-24 +TEXT ·add(SB), NOFRAME|NOSPLIT, $0-24 LDP x+8(FP), (R19, R20) LDP 0(R19), (R12, R13) LDP 16(R19), (R14, R15) @@ -12,17 +12,18 @@ TEXT ·add(SB), NOSPLIT, $0-24 LDP 0(R20), (R6, R7) LDP 16(R20), (R8, R9) LDP 32(R20), (R10, R11) + LDP ·qElement+0(SB), (R0, R1) + LDP ·qElement+16(SB), (R2, R3) + LDP ·qElement+32(SB), (R4, R5) ADDS R12, R6, R6 ADCS R13, R7, R7 ADCS R14, R8, R8 ADCS R15, R9, R9 ADCS R16, R10, R10 ADCS R17, R11, R11 + MOVD res+0(FP), R21 // load modulus and subtract - LDP ·qElement+0(SB), (R0, R1) - LDP ·qElement+16(SB), (R2, R3) - LDP ·qElement+32(SB), (R4, R5) SUBS R0, R6, R0 SBCS R1, R7, R1 SBCS R2, R8, R2 @@ -33,20 +34,17 @@ TEXT ·add(SB), NOSPLIT, $0-24 // reduce if necessary CSEL CS, R0, R6, R6 CSEL CS, R1, R7, R7 + STP (R6, R7), 0(R21) CSEL CS, R2, R8, R8 CSEL CS, R3, R9, R9 + STP (R8, R9), 16(R21) CSEL CS, R4, R10, R10 CSEL CS, R5, R11, R11 - - // store - MOVD res+0(FP), R21 - STP (R6, R7), 0(R21) - STP (R8, R9), 16(R21) STP (R10, R11), 32(R21) RET // double(res, x *Element) -TEXT ·double(SB), NOSPLIT, $0-16 +TEXT ·double(SB), NOFRAME|NOSPLIT, $0-16 LDP res+0(FP), (R1, R0) LDP 0(R0), (R2, R3) LDP 16(R0), (R4, R5) @@ -82,7 +80,7 @@ TEXT ·double(SB), NOSPLIT, $0-16 RET // sub(res, x, y *Element) -TEXT ·sub(SB), 
NOSPLIT, $0-24 +TEXT ·sub(SB), NOFRAME|NOSPLIT, $0-24 LDP x+8(FP), (R19, R20) LDP 0(R19), (R6, R7) LDP 16(R19), (R8, R9) @@ -122,7 +120,7 @@ TEXT ·sub(SB), NOSPLIT, $0-24 RET // butterfly(x, y *Element) -TEXT ·Butterfly(SB), NOSPLIT, $0-16 +TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 LDP x+0(FP), (R25, R26) LDP 0(R25), (R0, R1) LDP 16(R25), (R2, R3) @@ -186,7 +184,7 @@ TEXT ·Butterfly(SB), NOSPLIT, $0-16 RET // mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 +TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 #define DIVSHIFT() \ MUL R16, R24, R0 \ ADDS R0, R9, R9 \ diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index 4f1ba16ea8..14415944be 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -24,7 +24,7 @@ func (f *FFArm64) generateAdd() { defer f.AssertCleanStack(0, 0) // registers - t := registers.PopN(f.NbWords) + q := registers.PopN(f.NbWords) z := registers.PopN(f.NbWords) x := registers.PopN(f.NbWords) xPtr := registers.Pop() @@ -35,18 +35,17 @@ func (f *FFArm64) generateAdd() { f.load(xPtr, x) f.load(yPtr, z) + for i := 0; i < f.NbWords-1; i += 2 { + f.LDP(f.qAt(i), q[i], q[i+1]) + } f.ADDS(x[0], z[0], z[0]) for i := 1; i < f.NbWords; i++ { f.ADCS(x[i], z[i], z[i]) } - f.reduce(z, t) - - f.Comment("store") - f.MOVD("res+0(FP)", zPtr) - f.store(z, zPtr) + f.reduceAndStore(z, q, zPtr) f.RET() diff --git a/field/generator/generator.go b/field/generator/generator.go index e9c86dfdd7..647e20d225 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -210,6 +210,8 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude src := []string{ element.MulDoc, element.OpsARM64, + element.MulNoCarry, + element.Reduce, } pathSrc := filepath.Join(outputDir, eName+"_ops_arm64.go") bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index 65ccb7a68b..fae184cd11 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -397,6 +397,7 @@ func (z *{{.ElementName}}) fromMont() *{{.ElementName}} { } {{- if not .ASMArm}} + // Add z = x + y (mod q) func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { {{ $hasCarry := or (not $.NoCarry) (gt $.NbWords 1)}} diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 67a7836c65..78472adfa5 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -199,6 +199,7 @@ func (z *{{.ElementName}}) Square(x *{{.ElementName}}) *{{.ElementName}} { const OpsARM64 = ` {{if .ASMArm}} + //go:noescape func add(res,x,y *{{.ElementName}}) diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index d09dbeb7e5..35d777e493 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -207,6 +207,7 @@ func Benchmark{{toTitle .ElementName}}Mul(b *testing.B) { } } + func Benchmark{{toTitle .ElementName}}Cmp(b *testing.B) { x := {{.ElementName}}{ {{- range $i := .RSquare}} diff --git a/go.mod b/go.mod index 9006ce5c13..e5a21a2e84 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - 
github.com/consensys/bavard v0.1.23-0.20241015221109-a56d5bf777eb + github.com/consensys/bavard v0.1.23-0.20241019150039-28659c2eb91c github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index 1164581955..026f2ec62e 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/consensys/bavard v0.1.23-0.20241015221109-a56d5bf777eb h1:yPPmCz5FvvKMAKz/O7t5qJNZcEA0q6ermJzoL2D0oQU= github.com/consensys/bavard v0.1.23-0.20241015221109-a56d5bf777eb/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241019150039-28659c2eb91c h1:sK5i7h6ZVAj2eK7Vt5CzSnenlsxp828qvga+X5TjSVM= +github.com/consensys/bavard v0.1.23-0.20241019150039-28659c2eb91c/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= From a0e023e409bd44fcac783b7df469f026a24fc275 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sat, 19 Oct 2024 15:47:04 +0000 Subject: [PATCH 15/74] code cleaning --- ecc/bls12-377/fp/element.go | 69 ++++ ecc/bls12-377/fp/element_ops_arm64.go | 24 -- ecc/bls12-377/fp/element_ops_arm64.s | 2 +- ecc/bls12-377/fr/element.go | 57 ++++ ecc/bls12-377/fr/element_ops_arm64.go | 24 -- ecc/bls12-377/fr/element_ops_arm64.s | 2 +- ecc/bls12-381/fp/element.go | 69 ++++ ecc/bls12-381/fp/element_ops_arm64.go | 24 -- ecc/bls12-381/fp/element_ops_arm64.s | 2 +- ecc/bls12-381/fr/element.go | 57 ++++ ecc/bls12-381/fr/element_ops_arm64.go | 24 -- ecc/bls12-381/fr/element_ops_arm64.s | 2 +- ecc/bls24-315/fr/element.go | 57 ++++ ecc/bls24-315/fr/element_ops_arm64.go | 24 -- ecc/bls24-315/fr/element_ops_arm64.s | 2 +- ecc/bls24-317/fr/element.go | 57 ++++ ecc/bls24-317/fr/element_ops_arm64.go | 24 -- ecc/bls24-317/fr/element_ops_arm64.s | 2 +- ecc/bn254/fp/element.go | 57 ++++ ecc/bn254/fp/element_ops_arm64.go | 24 -- ecc/bn254/fp/element_ops_arm64.s | 2 +- ecc/bn254/fr/element.go | 57 ++++ ecc/bn254/fr/element_ops_arm64.go | 24 -- ecc/bn254/fr/element_ops_arm64.s | 2 +- ecc/bw6-761/fr/element.go | 69 ++++ ecc/bw6-761/fr/element_ops_arm64.go | 24 -- ecc/bw6-761/fr/element_ops_arm64.s | 2 +- ecc/stark-curve/fp/element.go | 57 ++++ ecc/stark-curve/fp/element_ops_arm64.go | 24 -- ecc/stark-curve/fp/element_ops_arm64.s | 2 +- ecc/stark-curve/fr/element.go | 57 ++++ ecc/stark-curve/fr/element_ops_arm64.go | 24 -- ecc/stark-curve/fr/element_ops_arm64.s | 2 +- field/asm/element_4w_arm64.s | 223 ++++--------- field/asm/element_6w_arm64.s | 308 ++++++----------- field/generator/asm/arm64/build.go | 3 - field/generator/asm/arm64/element_ops.go | 309 ++++-------------- field/generator/generator.go | 1 + .../internal/templates/element/base.go | 4 +- .../internal/templates/element/ops_asm.go | 24 -- 40 files changed, 908 insertions(+), 913 deletions(-) diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 408b94653d..393f45744d 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -418,6 +418,75 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + 
y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], carry = bits.Add64(x[3], y[3], carry) + z[4], carry = bits.Add64(x[4], y[4], carry) + z[5], _ = bits.Add64(x[5], y[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], carry = bits.Add64(x[3], x[3], carry) + z[4], carry = bits.Add64(x[4], x[4], carry) + z[5], _ = bits.Add64(x[5], x[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + z[4], b = bits.Sub64(x[4], y[4], b) + z[5], b = bits.Sub64(x[5], y[5], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], c = bits.Add64(z[3], q3, c) + z[4], c = bits.Add64(z[4], q4, c) + z[5], _ = bits.Add64(z[5], q5, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-377/fp/element_ops_arm64.go b/ecc/bls12-377/fp/element_ops_arm64.go index 1e0188e13a..78ae87b96b 100644 --- a/ecc/bls12-377/fp/element_ops_arm64.go +++ b/ecc/bls12-377/fp/element_ops_arm64.go @@ -19,30 +19,6 @@ package fp -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/bls12-377/fp/element_ops_arm64.s b/ecc/bls12-377/fp/element_ops_arm64.s index c4c5981603..f12adf4dc5 100644 --- a/ecc/bls12-377/fp/element_ops_arm64.s +++ b/ecc/bls12-377/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
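After this cleanup only Butterfly (and mul) keep assembly wrappers; Add, Double and Sub become the portable math/bits code above. What Butterfly computes — a, b = a+b, a-b (mod q) — can be expressed directly in terms of those methods. A minimal sketch, assuming it sits in the same field package as Element (the name butterflyGeneric is illustrative, not part of the patch):

	// butterflyGeneric mirrors the assembly Butterfly: a, b = a+b, a-b (mod q).
	// The copy of a is taken before the in-place Add so that Sub still sees
	// the original value of a.
	func butterflyGeneric(a, b *Element) {
		t := *a
		a.Add(a, b)
		b.Sub(&t, b)
	}

Keeping just this one routine in assembly is a plausible trade-off: it is the FFT hot loop, and fusing the add and sub lets the asm load each operand once and interleave the two reductions, which the two separate Go calls cannot do.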
-// We include the hash to force the Go compiler to recompile: 635053461911122795 +// We include the hash to force the Go compiler to recompile: 4799084555005768587 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index e37805eaa6..af277e8bb1 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -393,6 +393,63 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-377/fr/element_ops_arm64.go b/ecc/bls12-377/fr/element_ops_arm64.go index 9a57c7ca47..6759e524eb 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.go +++ b/ecc/bls12-377/fr/element_ops_arm64.go @@ -19,30 +19,6 @@ package fr -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s index 1501fe56d2..6ba54c61aa 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.s +++ b/ecc/bls12-377/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 11331350010912976978 +// We include the hash to force the Go compiler to recompile: 18027907654287790676 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index 1609ca9524..f0bcfe51bc 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -418,6 +418,75 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], carry = bits.Add64(x[3], y[3], carry) + z[4], carry = bits.Add64(x[4], y[4], carry) + z[5], _ = bits.Add64(x[5], y[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], carry = bits.Add64(x[3], x[3], carry) + z[4], carry = bits.Add64(x[4], x[4], carry) + z[5], _ = bits.Add64(x[5], x[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + z[4], b = bits.Sub64(x[4], y[4], b) + z[5], b = bits.Sub64(x[5], y[5], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], c = bits.Add64(z[3], q3, c) + z[4], c = bits.Add64(z[4], q4, c) + z[5], _ = bits.Add64(z[5], q5, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-381/fp/element_ops_arm64.go b/ecc/bls12-381/fp/element_ops_arm64.go index 1e0188e13a..78ae87b96b 100644 --- a/ecc/bls12-381/fp/element_ops_arm64.go +++ b/ecc/bls12-381/fp/element_ops_arm64.go @@ -19,30 +19,6 @@ package fp -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/bls12-381/fp/element_ops_arm64.s b/ecc/bls12-381/fp/element_ops_arm64.s index c4c5981603..f12adf4dc5 100644 --- a/ecc/bls12-381/fp/element_ops_arm64.s +++ b/ecc/bls12-381/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
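These stubs alternate between element_4w_arm64.s and element_6w_arm64.s (the fr fields here versus the bls12-377/bls12-381 fp and bw6-761 fr fields); the split is purely the number of 64-bit limbs the modulus needs. A sketch of that choice, with stand-in bit lengths — nbWords is an illustrative helper, not the repo's exact field_config code:

	package main

	import (
		"fmt"
		"math/big"
	)

	// nbWords is the number of 64-bit words needed to hold the modulus.
	func nbWords(q *big.Int) int {
		return (q.BitLen() + 63) / 64
	}

	func main() {
		fr := new(big.Int).Lsh(big.NewInt(1), 253) // stand-in for a 254-bit r
		fp := new(big.Int).Lsh(big.NewInt(1), 376) // stand-in for a 377-bit p
		fmt.Println(nbWords(fr), nbWords(fp))      // 4 6
	}
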
-// We include the hash to force the Go compiler to recompile: 635053461911122795 +// We include the hash to force the Go compiler to recompile: 4799084555005768587 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index bcc9c6e251..dc38f08cd3 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -393,6 +393,63 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls12-381/fr/element_ops_arm64.go b/ecc/bls12-381/fr/element_ops_arm64.go index 9a57c7ca47..6759e524eb 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.go +++ b/ecc/bls12-381/fr/element_ops_arm64.go @@ -19,30 +19,6 @@ package fr -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s index 1501fe56d2..6ba54c61aa 100644 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ b/ecc/bls12-381/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 11331350010912976978 +// We include the hash to force the Go compiler to recompile: 18027907654287790676 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index f565c96360..abdb822acf 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -393,6 +393,63 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls24-315/fr/element_ops_arm64.go b/ecc/bls24-315/fr/element_ops_arm64.go index 9a57c7ca47..6759e524eb 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.go +++ b/ecc/bls24-315/fr/element_ops_arm64.go @@ -19,30 +19,6 @@ package fr -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s index 1501fe56d2..6ba54c61aa 100644 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ b/ecc/bls24-315/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 11331350010912976978 +// We include the hash to force the Go compiler to recompile: 18027907654287790676 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element.go b/ecc/bls24-317/fr/element.go index f00c1ed57a..3aefaebe62 100644 --- a/ecc/bls24-317/fr/element.go +++ b/ecc/bls24-317/fr/element.go @@ -393,6 +393,63 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bls24-317/fr/element_ops_arm64.go b/ecc/bls24-317/fr/element_ops_arm64.go index 9a57c7ca47..6759e524eb 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.go +++ b/ecc/bls24-317/fr/element_ops_arm64.go @@ -19,30 +19,6 @@ package fr -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s index 1501fe56d2..6ba54c61aa 100644 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ b/ecc/bls24-317/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 11331350010912976978 +// We include the hash to force the Go compiler to recompile: 18027907654287790676 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index e58b0ac6b2..25fcdb67cc 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -393,6 +393,63 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bn254/fp/element_ops_arm64.go b/ecc/bn254/fp/element_ops_arm64.go index 1e0188e13a..78ae87b96b 100644 --- a/ecc/bn254/fp/element_ops_arm64.go +++ b/ecc/bn254/fp/element_ops_arm64.go @@ -19,30 +19,6 @@ package fp -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s index 1501fe56d2..6ba54c61aa 100644 --- a/ecc/bn254/fp/element_ops_arm64.s +++ b/ecc/bn254/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 11331350010912976978 +// We include the hash to force the Go compiler to recompile: 18027907654287790676 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index 186b9e6a39..3650c954c5 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -393,6 +393,63 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bn254/fr/element_ops_arm64.go b/ecc/bn254/fr/element_ops_arm64.go index 9a57c7ca47..6759e524eb 100644 --- a/ecc/bn254/fr/element_ops_arm64.go +++ b/ecc/bn254/fr/element_ops_arm64.go @@ -19,30 +19,6 @@ package fr -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s index 1501fe56d2..6ba54c61aa 100644 --- a/ecc/bn254/fr/element_ops_arm64.s +++ b/ecc/bn254/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 11331350010912976978 +// We include the hash to force the Go compiler to recompile: 18027907654287790676 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index a887b71537..6784bc911f 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -418,6 +418,75 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], carry = bits.Add64(x[3], y[3], carry) + z[4], carry = bits.Add64(x[4], y[4], carry) + z[5], _ = bits.Add64(x[5], y[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], carry = bits.Add64(x[3], x[3], carry) + z[4], carry = bits.Add64(x[4], x[4], carry) + z[5], _ = bits.Add64(x[5], x[5], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + z[4], b = bits.Sub64(x[4], y[4], b) + z[5], b = bits.Sub64(x[5], y[5], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], c = bits.Add64(z[3], q3, c) + z[4], c = bits.Add64(z[4], q4, c) + z[5], _ = bits.Add64(z[5], q5, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/bw6-761/fr/element_ops_arm64.go b/ecc/bw6-761/fr/element_ops_arm64.go index 9a57c7ca47..6759e524eb 100644 --- a/ecc/bw6-761/fr/element_ops_arm64.go +++ b/ecc/bw6-761/fr/element_ops_arm64.go @@ -19,30 +19,6 @@ package fr -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/bw6-761/fr/element_ops_arm64.s b/ecc/bw6-761/fr/element_ops_arm64.s index c4c5981603..f12adf4dc5 100644 --- a/ecc/bw6-761/fr/element_ops_arm64.s +++ b/ecc/bw6-761/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 635053461911122795 +// We include the hash to force the Go compiler to recompile: 4799084555005768587 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/stark-curve/fp/element.go b/ecc/stark-curve/fp/element.go index dc0de49c67..1c53dcb090 100644 --- a/ecc/stark-curve/fp/element.go +++ b/ecc/stark-curve/fp/element.go @@ -393,6 +393,63 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/stark-curve/fp/element_ops_arm64.go b/ecc/stark-curve/fp/element_ops_arm64.go index 1e0188e13a..78ae87b96b 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.go +++ b/ecc/stark-curve/fp/element_ops_arm64.go @@ -19,30 +19,6 @@ package fp -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s index 1501fe56d2..6ba54c61aa 100644 --- a/ecc/stark-curve/fp/element_ops_arm64.s +++ b/ecc/stark-curve/fp/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 11331350010912976978 +// We include the hash to force the Go compiler to recompile: 18027907654287790676 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element.go b/ecc/stark-curve/fr/element.go index 65c70f9722..216e287ebb 100644 --- a/ecc/stark-curve/fr/element.go +++ b/ecc/stark-curve/fr/element.go @@ -393,6 +393,63 @@ func (z *Element) fromMont() *Element { return z } +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + // Neg z = q - x func (z *Element) Neg(x *Element) *Element { if x.IsZero() { diff --git a/ecc/stark-curve/fr/element_ops_arm64.go b/ecc/stark-curve/fr/element_ops_arm64.go index 9a57c7ca47..6759e524eb 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.go +++ b/ecc/stark-curve/fr/element_ops_arm64.go @@ -19,30 +19,6 @@ package fr -//go:noescape -func add(res, x, y *Element) - -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -//go:noescape -func double(res, x *Element) - -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -//go:noescape -func sub(res, x, y *Element) - -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - //go:noescape func Butterfly(a, b *Element) diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s index 1501fe56d2..6ba54c61aa 100644 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ b/ecc/stark-curve/fr/element_ops_arm64.s @@ -1,6 +1,6 @@ // +build !purego // Code generated by gnark-crypto/generator. DO NOT EDIT. 
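The same three limb routines are stamped into every 4-word field above, so a single cross-check against math/big covers them all. A standalone sketch for the Add carry chain on raw limbs, outside Montgomery form; the helper names and the placeholder modulus are illustrative, not repo code:

	package main

	import (
		"fmt"
		"math/big"
		"math/bits"
	)

	// toBig assembles little-endian 64-bit limbs into a big.Int.
	func toBig(x [4]uint64) *big.Int {
		z := new(big.Int)
		for i := 3; i >= 0; i-- {
			z.Lsh(z, 64)
			z.Or(z, new(big.Int).SetUint64(x[i]))
		}
		return z
	}

	// geq reports whether a >= b, comparing limbs from the top down
	// (the role smallerThanModulus plays in the generated code).
	func geq(a, b [4]uint64) bool {
		for i := 3; i >= 0; i-- {
			if a[i] != b[i] {
				return a[i] > b[i]
			}
		}
		return true
	}

	// addMod is the same carry chain as the generated Add, on raw limbs.
	func addMod(x, y, q [4]uint64) (z [4]uint64) {
		var carry uint64
		z[0], carry = bits.Add64(x[0], y[0], 0)
		z[1], carry = bits.Add64(x[1], y[1], carry)
		z[2], carry = bits.Add64(x[2], y[2], carry)
		z[3], _ = bits.Add64(x[3], y[3], carry) // x+y < 2q < 2^256: carry is dead
		if geq(z, q) {
			var b uint64
			z[0], b = bits.Sub64(z[0], q[0], 0)
			z[1], b = bits.Sub64(z[1], q[1], b)
			z[2], b = bits.Sub64(z[2], q[2], b)
			z[3], _ = bits.Sub64(z[3], q[3], b)
		}
		return z
	}

	func main() {
		// Placeholder modulus and operands, NOT a real field: q = 2^254 + 13.
		q := [4]uint64{13, 0, 0, 1 << 62}
		x := [4]uint64{^uint64(0), 7, 0, 1 << 61}
		y := [4]uint64{42, 1, 0, 1 << 61}
		want := new(big.Int).Add(toBig(x), toBig(y))
		want.Mod(want, toBig(q))
		fmt.Println(toBig(addMod(x, y, q)).Cmp(want) == 0) // true
	}

Note the dropped final carry is only safe because every modulus used with this template leaves headroom in the top word, exactly the assumption the generated code relies on.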
-// We include the hash to force the Go compiler to recompile: 11331350010912976978 +// We include the hash to force the Go compiler to recompile: 18027907654287790676 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index 91a34e057b..071e2495c3 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -3,94 +3,8 @@ #include "funcdata.h" #include "go_asm.h" -// add(res, x, y *Element) -TEXT ·add(SB), NOFRAME|NOSPLIT, $0-24 - LDP x+8(FP), (R12, R13) - LDP 0(R12), (R8, R9) - LDP 16(R12), (R10, R11) - LDP 0(R13), (R4, R5) - LDP 16(R13), (R6, R7) - LDP ·qElement+0(SB), (R0, R1) - LDP ·qElement+16(SB), (R2, R3) - ADDS R8, R4, R4 - ADCS R9, R5, R5 - ADCS R10, R6, R6 - ADCS R11, R7, R7 - MOVD res+0(FP), R14 - - // load modulus and subtract - SUBS R0, R4, R0 - SBCS R1, R5, R1 - SBCS R2, R6, R2 - SBCS R3, R7, R3 - - // reduce if necessary - CSEL CS, R0, R4, R4 - CSEL CS, R1, R5, R5 - STP (R4, R5), 0(R14) - CSEL CS, R2, R6, R6 - CSEL CS, R3, R7, R7 - STP (R6, R7), 16(R14) - RET - -// double(res, x *Element) -TEXT ·double(SB), NOFRAME|NOSPLIT, $0-16 - LDP res+0(FP), (R1, R0) - LDP 0(R0), (R2, R3) - LDP 16(R0), (R4, R5) - ADDS R2, R2, R2 - ADCS R3, R3, R3 - ADCS R4, R4, R4 - ADCS R5, R5, R5 - - // load modulus and subtract - LDP ·qElement+0(SB), (R6, R7) - LDP ·qElement+16(SB), (R8, R9) - SUBS R6, R2, R6 - SBCS R7, R3, R7 - SBCS R8, R4, R8 - SBCS R9, R5, R9 - - // reduce if necessary - CSEL CS, R6, R2, R2 - CSEL CS, R7, R3, R3 - CSEL CS, R8, R4, R4 - CSEL CS, R9, R5, R5 - STP (R2, R3), 0(R1) - STP (R4, R5), 16(R1) - RET - -// sub(res, x, y *Element) -TEXT ·sub(SB), NOFRAME|NOSPLIT, $0-24 - LDP x+8(FP), (R12, R13) - LDP 0(R12), (R4, R5) - LDP 16(R12), (R6, R7) - LDP 0(R13), (R0, R1) - LDP 16(R13), (R2, R3) - SUBS R0, R4, R0 - SBCS R1, R5, R1 - SBCS R2, R6, R2 - SBCS R3, R7, R3 - - // load modulus and select - LDP ·qElement+0(SB), (R8, R9) - LDP ·qElement+16(SB), (R10, R11) - CSEL CS, ZR, R8, R8 - CSEL CS, ZR, R9, R9 - CSEL CS, ZR, R10, R10 - CSEL CS, ZR, R11, R11 - - // add q if underflow, 0 if not - ADDS R0, R8, R0 - ADCS R1, R9, R1 - ADCS R2, R10, R2 - ADCS R3, R11, R3 - MOVD res+0(FP), R14 - STP (R0, R1), 0(R14) - STP (R2, R3), 16(R14) - RET - -// butterfly(x, y *Element) +// butterfly(a, b *Element) +// a, b = a+b, a-b TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 LDP x+0(FP), (R16, R17) LDP 0(R16), (R0, R1) @@ -105,8 +19,6 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 SBCS R5, R1, R5 SBCS R6, R2, R6 SBCS R7, R3, R7 - - // load modulus and select LDP ·qElement+0(SB), (R0, R1) CSEL CS, ZR, R0, R12 CSEL CS, ZR, R1, R13 @@ -122,13 +34,13 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 ADC R7, R15, R7 STP (R6, R7), 16(R17) - // load modulus and subtract + // q = t - q SUBS R0, R8, R0 SBCS R1, R9, R1 SBCS R2, R10, R2 SBCS R3, R11, R3 - // reduce if necessary + // if no borrow, return q, else return t CSEL CS, R0, R8, R8 CSEL CS, R1, R9, R9 STP (R8, R9), 0(R16) @@ -138,72 +50,73 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 RET // mul(res, x, y *Element) +// Algorithm 2 of Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS +// by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 #define DIVSHIFT() \ - MUL R12, R17, R0 \ - ADDS R0, R7, R7 \ - MUL R13, R17, R0 \ - ADCS R0, R8, R8 \ - MUL R14, R17, R0 \ - ADCS R0, R9, R9 \ - MUL R15, R17, R0 \ - ADCS R0, R10, R10 \ - ADC ZR, R11, R11 \ - UMULH R12, R17, R0 \ - ADDS R0, R8, R7 \ - UMULH R13, R17, R0 \ - ADCS R0, R9, R8 \ - UMULH R14, R17, R0 \ - ADCS R0, R10, R9 \ - UMULH R15, R17, R0 \ - ADCS R0, R11, R10 \ + MUL R7, R17, R0 \ + ADDS R0, R11, R11 \ + MUL R8, R17, R0 \ + ADCS R0, R12, R12 \ + MUL R9, R17, R0 \ + ADCS R0, R13, R13 \ + MUL R10, R17, R0 \ + ADCS R0, R14, R14 \ + ADC R15, ZR, R15 \ + UMULH R7, R17, R0 \ + ADDS R0, R12, R11 \ + UMULH R8, R17, R0 \ + ADCS R0, R13, R12 \ + UMULH R9, R17, R0 \ + ADCS R0, R14, R13 \ + UMULH R10, R17, R0 \ + ADCS R0, R15, R14 \ #define MUL_WORD_N() \ - MUL R3, R2, R0 \ - ADDS R0, R7, R7 \ - MUL R7, R16, R17 \ - MUL R4, R2, R0 \ - ADCS R0, R8, R8 \ - MUL R5, R2, R0 \ - ADCS R0, R9, R9 \ - MUL R6, R2, R0 \ - ADCS R0, R10, R10 \ - ADC ZR, ZR, R11 \ - UMULH R3, R2, R0 \ - ADDS R0, R8, R8 \ - UMULH R4, R2, R0 \ - ADCS R0, R9, R9 \ - UMULH R5, R2, R0 \ - ADCS R0, R10, R10 \ - UMULH R6, R2, R0 \ - ADC R0, R11, R11 \ - DIVSHIFT() \ + MUL R3, R2, R0 \ + ADDS R0, R11, R11 \ + MUL R11, R16, R17 \ + MUL R4, R2, R0 \ + ADCS R0, R12, R12 \ + MUL R5, R2, R0 \ + ADCS R0, R13, R13 \ + MUL R6, R2, R0 \ + ADCS R0, R14, R14 \ + ADC ZR, ZR, R15 \ + UMULH R3, R2, R0 \ + ADDS R0, R12, R12 \ + UMULH R4, R2, R0 \ + ADCS R0, R13, R13 \ + UMULH R5, R2, R0 \ + ADCS R0, R14, R14 \ + UMULH R6, R2, R0 \ + ADC R0, R15, R15 \ + DIVSHIFT() \ #define MUL_WORD_0() \ - MUL R3, R2, R7 \ - MUL R4, R2, R8 \ - MUL R5, R2, R9 \ - MUL R6, R2, R10 \ - UMULH R3, R2, R0 \ - ADDS R0, R8, R8 \ - UMULH R4, R2, R0 \ - ADCS R0, R9, R9 \ - UMULH R5, R2, R0 \ - ADCS R0, R10, R10 \ - UMULH R6, R2, R0 \ - ADC R0, ZR, R11 \ - MUL R7, R16, R17 \ - DIVSHIFT() \ + MUL R3, R2, R11 \ + MUL R4, R2, R12 \ + MUL R5, R2, R13 \ + MUL R6, R2, R14 \ + UMULH R3, R2, R0 \ + ADDS R0, R12, R12 \ + UMULH R4, R2, R0 \ + ADCS R0, R13, R13 \ + UMULH R5, R2, R0 \ + ADCS R0, R14, R14 \ + UMULH R6, R2, R0 \ + ADC R0, ZR, R15 \ + MUL R11, R16, R17 \ + DIVSHIFT() \ - // mul body MOVD y+16(FP), R1 MOVD x+8(FP), R0 LDP 0(R0), (R3, R4) LDP 16(R0), (R5, R6) MOVD 0(R1), R2 MOVD $const_qInvNeg, R16 - LDP ·qElement+0(SB), (R12, R13) - LDP ·qElement+16(SB), (R14, R15) + LDP ·qElement+0(SB), (R7, R8) + LDP ·qElement+16(SB), (R9, R10) MUL_WORD_0() MOVD 8(R1), R2 MUL_WORD_N() @@ -213,15 +126,15 @@ TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 MUL_WORD_N() // reduce if necessary - SUBS R12, R7, R12 - SBCS R13, R8, R13 - SBCS R14, R9, R14 - SBCS R15, R10, R15 + SUBS R7, R11, R7 + SBCS R8, R12, R8 + SBCS R9, R13, R9 + SBCS R10, R14, R10 MOVD res+0(FP), R0 - CSEL CS, R12, R7, R7 - CSEL CS, R13, R8, R8 - STP (R7, R8), 0(R0) - CSEL CS, R14, R9, R9 - CSEL CS, R15, R10, R10 - STP (R9, R10), 16(R0) + CSEL CS, R7, R11, R11 + CSEL CS, R8, R12, R12 + STP (R11, R12), 0(R0) + CSEL CS, R9, R13, R13 + CSEL CS, R10, R14, R14 + STP (R13, R14), 16(R0) RET diff --git a/field/asm/element_6w_arm64.s b/field/asm/element_6w_arm64.s index c4a5ffd7bb..bcba9ee6b1 100644 --- a/field/asm/element_6w_arm64.s +++ b/field/asm/element_6w_arm64.s @@ -3,123 +3,8 @@ #include "funcdata.h" #include "go_asm.h" -// add(res, x, y *Element) -TEXT ·add(SB), NOFRAME|NOSPLIT, $0-24 - LDP x+8(FP), (R19, R20) - LDP 0(R19), (R12, R13) - LDP 16(R19), (R14, R15) - LDP 32(R19), (R16, R17) - LDP 0(R20), (R6, R7) - LDP 16(R20), (R8, R9) 
- LDP 32(R20), (R10, R11) - LDP ·qElement+0(SB), (R0, R1) - LDP ·qElement+16(SB), (R2, R3) - LDP ·qElement+32(SB), (R4, R5) - ADDS R12, R6, R6 - ADCS R13, R7, R7 - ADCS R14, R8, R8 - ADCS R15, R9, R9 - ADCS R16, R10, R10 - ADCS R17, R11, R11 - MOVD res+0(FP), R21 - - // load modulus and subtract - SUBS R0, R6, R0 - SBCS R1, R7, R1 - SBCS R2, R8, R2 - SBCS R3, R9, R3 - SBCS R4, R10, R4 - SBCS R5, R11, R5 - - // reduce if necessary - CSEL CS, R0, R6, R6 - CSEL CS, R1, R7, R7 - STP (R6, R7), 0(R21) - CSEL CS, R2, R8, R8 - CSEL CS, R3, R9, R9 - STP (R8, R9), 16(R21) - CSEL CS, R4, R10, R10 - CSEL CS, R5, R11, R11 - STP (R10, R11), 32(R21) - RET - -// double(res, x *Element) -TEXT ·double(SB), NOFRAME|NOSPLIT, $0-16 - LDP res+0(FP), (R1, R0) - LDP 0(R0), (R2, R3) - LDP 16(R0), (R4, R5) - LDP 32(R0), (R6, R7) - ADDS R2, R2, R2 - ADCS R3, R3, R3 - ADCS R4, R4, R4 - ADCS R5, R5, R5 - ADCS R6, R6, R6 - ADCS R7, R7, R7 - - // load modulus and subtract - LDP ·qElement+0(SB), (R8, R9) - LDP ·qElement+16(SB), (R10, R11) - LDP ·qElement+32(SB), (R12, R13) - SUBS R8, R2, R8 - SBCS R9, R3, R9 - SBCS R10, R4, R10 - SBCS R11, R5, R11 - SBCS R12, R6, R12 - SBCS R13, R7, R13 - - // reduce if necessary - CSEL CS, R8, R2, R2 - CSEL CS, R9, R3, R3 - CSEL CS, R10, R4, R4 - CSEL CS, R11, R5, R5 - CSEL CS, R12, R6, R6 - CSEL CS, R13, R7, R7 - STP (R2, R3), 0(R1) - STP (R4, R5), 16(R1) - STP (R6, R7), 32(R1) - RET - -// sub(res, x, y *Element) -TEXT ·sub(SB), NOFRAME|NOSPLIT, $0-24 - LDP x+8(FP), (R19, R20) - LDP 0(R19), (R6, R7) - LDP 16(R19), (R8, R9) - LDP 32(R19), (R10, R11) - LDP 0(R20), (R0, R1) - LDP 16(R20), (R2, R3) - LDP 32(R20), (R4, R5) - SUBS R0, R6, R0 - SBCS R1, R7, R1 - SBCS R2, R8, R2 - SBCS R3, R9, R3 - SBCS R4, R10, R4 - SBCS R5, R11, R5 - - // load modulus and select - LDP ·qElement+0(SB), (R12, R13) - LDP ·qElement+16(SB), (R14, R15) - LDP ·qElement+32(SB), (R16, R17) - CSEL CS, ZR, R12, R12 - CSEL CS, ZR, R13, R13 - CSEL CS, ZR, R14, R14 - CSEL CS, ZR, R15, R15 - CSEL CS, ZR, R16, R16 - CSEL CS, ZR, R17, R17 - - // add q if underflow, 0 if not - ADDS R0, R12, R0 - ADCS R1, R13, R1 - ADCS R2, R14, R2 - ADCS R3, R15, R3 - ADCS R4, R16, R4 - ADCS R5, R17, R5 - MOVD res+0(FP), R21 - STP (R0, R1), 0(R21) - STP (R2, R3), 16(R21) - STP (R4, R5), 32(R21) - RET - -// butterfly(x, y *Element) +// butterfly(a, b *Element) +// a, b = a+b, a-b TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 LDP x+0(FP), (R25, R26) LDP 0(R25), (R0, R1) @@ -140,8 +25,6 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 SBCS R9, R3, R9 SBCS R10, R4, R10 SBCS R11, R5, R11 - - // load modulus and select LDP ·qElement+0(SB), (R0, R1) CSEL CS, ZR, R0, R19 CSEL CS, ZR, R1, R20 @@ -163,7 +46,7 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 ADC R11, R24, R11 STP (R10, R11), 32(R26) - // load modulus and subtract + // q = t - q SUBS R0, R12, R0 SBCS R1, R13, R1 SBCS R2, R14, R2 @@ -171,7 +54,7 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 SBCS R4, R16, R4 SBCS R5, R17, R5 - // reduce if necessary + // if no borrow, return q, else return t CSEL CS, R0, R12, R12 CSEL CS, R1, R13, R13 STP (R12, R13), 0(R25) @@ -184,86 +67,87 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 RET // mul(res, x, y *Element) +// Algorithm 2 of Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS +// by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 #define DIVSHIFT() \ - MUL R16, R24, R0 \ - ADDS R0, R9, R9 \ - MUL R17, R24, R0 \ - ADCS R0, R10, R10 \ - MUL R19, R24, R0 \ - ADCS R0, R11, R11 \ - MUL R20, R24, R0 \ - ADCS R0, R12, R12 \ - MUL R21, R24, R0 \ - ADCS R0, R13, R13 \ - MUL R22, R24, R0 \ - ADCS R0, R14, R14 \ - ADC ZR, R15, R15 \ - UMULH R16, R24, R0 \ - ADDS R0, R10, R9 \ - UMULH R17, R24, R0 \ - ADCS R0, R11, R10 \ - UMULH R19, R24, R0 \ - ADCS R0, R12, R11 \ - UMULH R20, R24, R0 \ - ADCS R0, R13, R12 \ - UMULH R21, R24, R0 \ - ADCS R0, R14, R13 \ - UMULH R22, R24, R0 \ - ADCS R0, R15, R14 \ + MUL R9, R24, R0 \ + ADDS R0, R15, R15 \ + MUL R10, R24, R0 \ + ADCS R0, R16, R16 \ + MUL R11, R24, R0 \ + ADCS R0, R17, R17 \ + MUL R12, R24, R0 \ + ADCS R0, R19, R19 \ + MUL R13, R24, R0 \ + ADCS R0, R20, R20 \ + MUL R14, R24, R0 \ + ADCS R0, R21, R21 \ + ADC R22, ZR, R22 \ + UMULH R9, R24, R0 \ + ADDS R0, R16, R15 \ + UMULH R10, R24, R0 \ + ADCS R0, R17, R16 \ + UMULH R11, R24, R0 \ + ADCS R0, R19, R17 \ + UMULH R12, R24, R0 \ + ADCS R0, R20, R19 \ + UMULH R13, R24, R0 \ + ADCS R0, R21, R20 \ + UMULH R14, R24, R0 \ + ADCS R0, R22, R21 \ #define MUL_WORD_N() \ - MUL R3, R2, R0 \ - ADDS R0, R9, R9 \ - MUL R9, R23, R24 \ - MUL R4, R2, R0 \ - ADCS R0, R10, R10 \ - MUL R5, R2, R0 \ - ADCS R0, R11, R11 \ - MUL R6, R2, R0 \ - ADCS R0, R12, R12 \ - MUL R7, R2, R0 \ - ADCS R0, R13, R13 \ - MUL R8, R2, R0 \ - ADCS R0, R14, R14 \ - ADC ZR, ZR, R15 \ - UMULH R3, R2, R0 \ - ADDS R0, R10, R10 \ - UMULH R4, R2, R0 \ - ADCS R0, R11, R11 \ - UMULH R5, R2, R0 \ - ADCS R0, R12, R12 \ - UMULH R6, R2, R0 \ - ADCS R0, R13, R13 \ - UMULH R7, R2, R0 \ - ADCS R0, R14, R14 \ - UMULH R8, R2, R0 \ - ADC R0, R15, R15 \ - DIVSHIFT() \ + MUL R3, R2, R0 \ + ADDS R0, R15, R15 \ + MUL R15, R23, R24 \ + MUL R4, R2, R0 \ + ADCS R0, R16, R16 \ + MUL R5, R2, R0 \ + ADCS R0, R17, R17 \ + MUL R6, R2, R0 \ + ADCS R0, R19, R19 \ + MUL R7, R2, R0 \ + ADCS R0, R20, R20 \ + MUL R8, R2, R0 \ + ADCS R0, R21, R21 \ + ADC ZR, ZR, R22 \ + UMULH R3, R2, R0 \ + ADDS R0, R16, R16 \ + UMULH R4, R2, R0 \ + ADCS R0, R17, R17 \ + UMULH R5, R2, R0 \ + ADCS R0, R19, R19 \ + UMULH R6, R2, R0 \ + ADCS R0, R20, R20 \ + UMULH R7, R2, R0 \ + ADCS R0, R21, R21 \ + UMULH R8, R2, R0 \ + ADC R0, R22, R22 \ + DIVSHIFT() \ #define MUL_WORD_0() \ - MUL R3, R2, R9 \ - MUL R4, R2, R10 \ - MUL R5, R2, R11 \ - MUL R6, R2, R12 \ - MUL R7, R2, R13 \ - MUL R8, R2, R14 \ - UMULH R3, R2, R0 \ - ADDS R0, R10, R10 \ - UMULH R4, R2, R0 \ - ADCS R0, R11, R11 \ - UMULH R5, R2, R0 \ - ADCS R0, R12, R12 \ - UMULH R6, R2, R0 \ - ADCS R0, R13, R13 \ - UMULH R7, R2, R0 \ - ADCS R0, R14, R14 \ - UMULH R8, R2, R0 \ - ADC R0, ZR, R15 \ - MUL R9, R23, R24 \ - DIVSHIFT() \ + MUL R3, R2, R15 \ + MUL R4, R2, R16 \ + MUL R5, R2, R17 \ + MUL R6, R2, R19 \ + MUL R7, R2, R20 \ + MUL R8, R2, R21 \ + UMULH R3, R2, R0 \ + ADDS R0, R16, R16 \ + UMULH R4, R2, R0 \ + ADCS R0, R17, R17 \ + UMULH R5, R2, R0 \ + ADCS R0, R19, R19 \ + UMULH R6, R2, R0 \ + ADCS R0, R20, R20 \ + UMULH R7, R2, R0 \ + ADCS R0, R21, R21 \ + UMULH R8, R2, R0 \ + ADC R0, ZR, R22 \ + MUL R15, R23, R24 \ + DIVSHIFT() \ - // mul body MOVD y+16(FP), R1 MOVD x+8(FP), R0 LDP 0(R0), (R3, R4) @@ -271,9 +155,9 @@ TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 LDP 32(R0), (R7, R8) MOVD 0(R1), R2 MOVD $const_qInvNeg, R23 - LDP ·qElement+0(SB), (R16, R17) - LDP ·qElement+16(SB), (R19, R20) - LDP ·qElement+32(SB), (R21, R22) + LDP ·qElement+0(SB), (R9, R10) + LDP ·qElement+16(SB), (R11, R12) 
+	LDP ·qElement+32(SB), (R13, R14)
	MUL_WORD_0()
	MOVD 8(R1), R2
	MUL_WORD_N()
@@ -287,20 +171,20 @@ TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24
	MUL_WORD_N()

	// reduce if necessary
-	SUBS R16, R9, R16
-	SBCS R17, R10, R17
-	SBCS R19, R11, R19
-	SBCS R20, R12, R20
-	SBCS R21, R13, R21
-	SBCS R22, R14, R22
+	SUBS R9, R15, R9
+	SBCS R10, R16, R10
+	SBCS R11, R17, R11
+	SBCS R12, R19, R12
+	SBCS R13, R20, R13
+	SBCS R14, R21, R14
	MOVD res+0(FP), R0
-	CSEL CS, R16, R9, R9
-	CSEL CS, R17, R10, R10
-	STP (R9, R10), 0(R0)
-	CSEL CS, R19, R11, R11
-	CSEL CS, R20, R12, R12
-	STP (R11, R12), 16(R0)
-	CSEL CS, R21, R13, R13
-	CSEL CS, R22, R14, R14
-	STP (R13, R14), 32(R0)
+	CSEL CS, R9, R15, R15
+	CSEL CS, R10, R16, R16
+	STP (R15, R16), 0(R0)
+	CSEL CS, R11, R17, R17
+	CSEL CS, R12, R19, R19
+	STP (R17, R19), 16(R0)
+	CSEL CS, R13, R20, R20
+	CSEL CS, R14, R21, R21
+	STP (R20, R21), 32(R0)
	RET

diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go
index 3734a8078a..ead32ae665 100644
--- a/field/generator/asm/arm64/build.go
+++ b/field/generator/asm/arm64/build.go
@@ -107,9 +107,6 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error {
 		panic("NbWords must be even")
 	}

-	f.generateAdd()
-	f.generateDouble()
-	f.generateSub()
 	f.generateButterfly()
 	f.generateMul()

diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go
index 14415944be..8d427a1286 100644
--- a/field/generator/asm/arm64/element_ops.go
+++ b/field/generator/asm/arm64/element_ops.go
@@ -18,126 +18,16 @@ import (
 	"github.com/consensys/bavard/arm64"
 )

-func (f *FFArm64) generateAdd() {
-	f.Comment("add(res, x, y *Element)")
-	registers := f.FnHeader("add", 0, 24)
-	defer f.AssertCleanStack(0, 0)
-
-	// registers
-	q := registers.PopN(f.NbWords)
-	z := registers.PopN(f.NbWords)
-	x := registers.PopN(f.NbWords)
-	xPtr := registers.Pop()
-	yPtr := registers.Pop()
-	zPtr := registers.Pop()
-
-	f.LDP("x+8(FP)", xPtr, yPtr)
-
-	f.load(xPtr, x)
-	f.load(yPtr, z)
-	for i := 0; i < f.NbWords-1; i += 2 {
-		f.LDP(f.qAt(i), q[i], q[i+1])
-	}
-
-	f.ADDS(x[0], z[0], z[0])
-	for i := 1; i < f.NbWords; i++ {
-		f.ADCS(x[i], z[i], z[i])
-	}
-
-	f.MOVD("res+0(FP)", zPtr)
-	f.reduceAndStore(z, q, zPtr)
-
-	f.RET()
-
-}
-
-func (f *FFArm64) generateDouble() {
-	f.Comment("double(res, x *Element)")
-	registers := f.FnHeader("double", 0, 16)
-	defer f.AssertCleanStack(0, 0)
-
-	// registers
-	xPtr := registers.Pop()
-	zPtr := registers.Pop()
-	z := registers.PopN(f.NbWords)
-	t := registers.PopN(f.NbWords)
-
-	f.LDP("res+0(FP)", zPtr, xPtr)
-
-	f.load(xPtr, z)
-
-	f.ADDS(z[0], z[0], z[0])
-	for i := 1; i < f.NbWords; i++ {
-		f.ADCS(z[i], z[i], z[i])
-	}
-
-	f.reduce(z, t)
-
-	f.store(z, zPtr)
-
-	f.RET()
-
-}
-
-// generateSub uses one more register than generateAdd, but that's okay since we have 29 registers available.
-func (f *FFArm64) generateSub() { - f.Comment("sub(res, x, y *Element)") - - registers := f.FnHeader("sub", 0, 24) - defer f.AssertCleanStack(0, 0) - - // registers - z := registers.PopN(f.NbWords) - x := registers.PopN(f.NbWords) - t := registers.PopN(f.NbWords) - xPtr := registers.Pop() - yPtr := registers.Pop() - zPtr := registers.Pop() - - f.LDP("x+8(FP)", xPtr, yPtr) - - f.load(xPtr, x) - f.load(yPtr, z) - - f.SUBS(z[0], x[0], z[0]) - for i := 1; i < f.NbWords; i++ { - f.SBCS(z[i], x[i], z[i]) - } - - f.Comment("load modulus and select") - - zero := arm64.Register("ZR") - - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.qAt(i), t[i], t[i+1]) - } - for i := 0; i < f.NbWords; i++ { - f.CSEL("CS", zero, t[i], t[i]) - } - f.Comment("add q if underflow, 0 if not") - f.ADDS(z[0], t[0], z[0]) - for i := 1; i < f.NbWords; i++ { - f.ADCS(z[i], t[i], z[i]) - } - - f.MOVD("res+0(FP)", zPtr) - f.store(z, zPtr) - - f.RET() - -} - func (f *FFArm64) generateButterfly() { - f.Comment("butterfly(x, y *Element)") + f.Comment("butterfly(a, b *Element)") + f.Comment("a, b = a+b, a-b") registers := f.FnHeader("Butterfly", 0, 16) defer f.AssertCleanStack(0, 0) - // Butterfly sets - // a = a + b (mod q) - // b = a - b (mod q) + // registers a := registers.PopN(f.NbWords) b := registers.PopN(f.NbWords) - aRes := registers.PopN(f.NbWords) + r := registers.PopN(f.NbWords) t := registers.PopN(f.NbWords) aPtr := registers.Pop() bPtr := registers.Pop() @@ -146,62 +36,38 @@ func (f *FFArm64) generateButterfly() { f.load(aPtr, a) f.load(bPtr, b) - f.ADDS(a[0], b[0], aRes[0]) - for i := 1; i < f.NbWords; i++ { - - if i == f.NbWordsLastIndex { - f.ADC(a[i], b[i], aRes[i]) - } else { - f.ADCS(a[i], b[i], aRes[i]) - } + for i := 0; i < f.NbWords; i++ { + f.add0n(i)(a[i], b[i], r[i]) } - // f.reduce(aRes, t) - - // f.Comment("store") - - // f.store(aRes, aPtr) - - bRes := b - - f.SUBS(b[0], a[0], bRes[0]) + f.SUBS(b[0], a[0], b[0]) for i := 1; i < f.NbWords; i++ { - f.SBCS(b[i], a[i], bRes[i]) + f.SBCS(b[i], a[i], b[i]) } - f.Comment("load modulus and select") - - zero := arm64.Register("ZR") - - // for i := 0; i < f.NbWords-1; i += 2 { - // f.LDP(f.qAt(i), t[i], t[i+1]) - // } for i := 0; i < f.NbWords; i++ { if i%2 == 0 { f.LDP(f.qAt(i), a[i], a[i+1]) } - f.CSEL("CS", zero, a[i], t[i]) + f.CSEL("CS", "ZR", a[i], t[i]) } f.Comment("add q if underflow, 0 if not") - f.ADDS(bRes[0], t[0], bRes[0]) - for i := 1; i < f.NbWords; i++ { - if i == f.NbWordsLastIndex { - f.ADC(bRes[i], t[i], bRes[i]) - } else { - f.ADCS(bRes[i], t[i], bRes[i]) - } + for i := 0; i < f.NbWords; i++ { + f.add0n(i)(b[i], t[i], b[i]) if i%2 == 1 { - f.STP(bRes[i-1], bRes[i], bPtr.At(i-1)) + f.STP(b[i-1], b[i], bPtr.At(i-1)) } } - f.reduceAndStore(aRes, a, aPtr) + f.reduceAndStore(r, a, aPtr) f.RET() } func (f *FFArm64) generateMul() { f.Comment("mul(res, x, y *Element)") + f.Comment("Algorithm 2 of Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS") + f.Comment("by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521") registers := f.FnHeader("mul", 0, 24) defer f.AssertCleanStack(0, 0) @@ -209,38 +75,27 @@ func (f *FFArm64) generateMul() { yPtr := registers.Pop() bi := registers.Pop() a := registers.PopN(f.NbWords) - t := registers.PopN(f.NbWords + 1) q := registers.PopN(f.NbWords) + t := registers.PopN(f.NbWords + 1) ax := xPtr qInv0 := registers.Pop() m := registers.Pop() divShift := f.Define("divShift", 0, func(args ...arm64.Register) { - // m := bi - // f.MUL(t[0], qInv0, m) - // for j=0 to N-1 // (C,t[j-1]) := t[j] + m*q[j] + C for j := 0; j < f.NbWords; j++ { f.MUL(q[j], m, ax) - if j == 0 { - f.ADDS(ax, t[j], t[j]) - } else { - f.ADCS(ax, t[j], t[j]) - } + f.add0m(j)(ax, t[j], t[j]) } - f.ADC("ZR", t[f.NbWords], t[f.NbWords]) + f.add0m(f.NbWords)(t[f.NbWords], "ZR", t[f.NbWords]) // propagate high bits f.UMULH(q[0], m, ax) for j := 1; j <= f.NbWords; j++ { - if j == 1 { - f.ADDS(ax, t[j], t[j-1]) - } else { - f.ADCS(ax, t[j], t[j-1]) - } + f.add1m(j, true)(ax, t[j], t[j-1]) if j != f.NbWords { f.UMULH(q[j], m, ax) } @@ -254,31 +109,18 @@ func (f *FFArm64) generateMul() { // lo bits for j := 0; j < f.NbWords; j++ { f.MUL(a[j], bi, ax) + f.add0m(j)(ax, t[j], t[j]) if j == 0 { - f.ADDS(ax, t[j], t[j]) f.MUL(t[0], qInv0, m) - } else { - f.ADCS(ax, t[j], t[j]) } } - - f.ADC("ZR", "ZR", t[f.NbWords]) + f.add0m(f.NbWords)("ZR", "ZR", t[f.NbWords]) // propagate high bits f.UMULH(a[0], bi, ax) for j := 1; j <= f.NbWords; j++ { - if j == 1 { - f.ADDS(ax, t[j], t[j]) - - } else { - if j == f.NbWords { - - f.ADC(ax, t[j], t[j]) - } else { - f.ADCS(ax, t[j], t[j]) - } - } + f.add1m(j)(ax, t[j], t[j]) if j != f.NbWords { f.UMULH(a[j], bi, ax) } @@ -289,7 +131,6 @@ func (f *FFArm64) generateMul() { mulWord0 := f.Define("MUL_WORD_0", 0, func(args ...arm64.Register) { // for j=0 to N-1 // (C,t[j]) := t[j] + a[j]*b[i] + C - // lo bits for j := 0; j < f.NbWords; j++ { f.MUL(a[j], bi, t[j]) @@ -297,30 +138,15 @@ func (f *FFArm64) generateMul() { // propagate high bits f.UMULH(a[0], bi, ax) - - for j := 1; j <= f.NbWords; j++ { - if j == 1 { - f.ADDS(ax, t[j], t[j]) - - } else { - if j == f.NbWords { - - f.ADC(ax, "ZR", t[j]) - } else { - f.ADCS(ax, t[j], t[j]) - } - } - if j != f.NbWords { - f.UMULH(a[j], bi, ax) - } + for j := 1; j < f.NbWords; j++ { + f.add1m(j)(ax, t[j], t[j]) + f.UMULH(a[j], bi, ax) } + f.add1m(f.NbWords)(ax, "ZR", t[f.NbWords]) f.MUL(t[0], qInv0, m) divShift() }) - f.Comment("mul body") - - // f.LDP("x+8(FP)", xPtr, yPtr) f.MOVD("y+16(FP)", yPtr) f.MOVD("x+8(FP)", xPtr) f.load(xPtr, a) @@ -328,12 +154,11 @@ func (f *FFArm64) generateMul() { f.MOVD(yPtr.At(i), bi) if i == 0 { + // load qInv0 and q at first iteration. 
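+			// qInv0 holds qInvNeg = -q^{-1} mod 2^64: each iteration uses it to
+			// compute m = t[0] * qInv0, the multiplier that makes t + m*q
+			// divisible by 2^64 (the Montgomery reduction step).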
f.MOVD(f.qInv0(), qInv0) - for i := 0; i < f.NbWords-1; i += 2 { f.LDP(f.qAt(i), q[i], q[i+1]) } - mulWord0() } else { mulWordN() @@ -354,66 +179,68 @@ func (f *FFArm64) generateMul() { } } - // f.store(t[:f.NbWords], resPtr) - f.RET() } -func (f *FFArm64) reduce(t, q []arm64.Register) { - - if len(t) != f.NbWords || len(q) != f.NbWords { - panic("need 2*nbWords registers") - } - - f.Comment("load modulus and subtract") - +func (f *FFArm64) load(zPtr arm64.Register, z []arm64.Register) { for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.qAt(i), q[i], q[i+1]) + f.LDP(zPtr.At(i), z[i], z[i+1]) } +} + +// q must contain the modulus +// q is modified +// t = t mod q (t must be less than 2q) +// t is stored in zPtr +func (f *FFArm64) reduceAndStore(t, q []arm64.Register, zPtr arm64.Register) { + f.Comment("q = t - q") f.SUBS(q[0], t[0], q[0]) for i := 1; i < f.NbWords; i++ { f.SBCS(q[i], t[i], q[i]) } - f.Comment("reduce if necessary") + f.Comment("if no borrow, return q, else return t") for i := 0; i < f.NbWords; i++ { f.CSEL("CS", q[i], t[i], t[i]) + if i%2 == 1 { + f.STP(t[i-1], t[i], zPtr.At(i-1)) + } } } -func (f *FFArm64) load(zPtr arm64.Register, z []arm64.Register) { - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(zPtr.At(i), z[i], z[i+1]) +func (f *FFArm64) add0n(i int) func(op1, op2, dst interface{}, comment ...string) { + switch { + case i == 0: + return f.ADDS + case i == f.NbWordsLastIndex: + return f.ADC + default: + return f.ADCS } } -func (f *FFArm64) store(z []arm64.Register, zPtr arm64.Register) { - for i := 0; i < f.NbWords-1; i += 2 { - f.STP(z[i], z[i+1], zPtr.At(i)) +func (f *FFArm64) add0m(i int) func(op1, op2, dst interface{}, comment ...string) { + switch { + case i == 0: + return f.ADDS + case i == f.NbWordsLastIndex+1: + return f.ADC + default: + return f.ADCS } } -func (f *FFArm64) reduceAndStore(t, q []arm64.Register, zPtr arm64.Register) { - - if len(t) != f.NbWords || len(q) != f.NbWords { - panic("need 2*nbWords registers") - } - - f.Comment("load modulus and subtract") - - // for i := 0; i < f.NbWords-1; i += 2 { - // f.LDP(f.qAt(i), q[i], q[i+1]) - // } - f.SUBS(q[0], t[0], q[0]) - for i := 1; i < f.NbWords; i++ { - f.SBCS(q[i], t[i], q[i]) - } - - f.Comment("reduce if necessary") - for i := 0; i < f.NbWords; i++ { - f.CSEL("CS", q[i], t[i], t[i]) - if i%2 == 1 { - f.STP(t[i-1], t[i], zPtr.At(i-1)) +func (f *FFArm64) add1m(i int, dumb ...bool) func(op1, op2, dst interface{}, comment ...string) { + switch { + case i == 1: + return f.ADDS + case i == f.NbWordsLastIndex+1: + if len(dumb) == 1 && dumb[0] { + // odd, but it performs better on c8g instances. 
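+			// ADCS additionally writes the condition flags, but they are
+			// overwritten by the next flag-setting instruction (ADDS/SUBS)
+			// before being read, so it is functionally equivalent to ADC here.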
+ return f.ADCS } + return f.ADC + default: + return f.ADCS } } diff --git a/field/generator/generator.go b/field/generator/generator.go index 647e20d225..29ed67accd 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -63,6 +63,7 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude "_mul_arm64.s", "_mul_arm64.go", "_ops_amd64.s", + "_ops_arm64.s", "_ops_noasm.go", "_mul_adx_amd64.s", "_ops_amd64.go", diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index fae184cd11..dd05519baf 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -396,8 +396,6 @@ func (z *{{.ElementName}}) fromMont() *{{.ElementName}} { return z } -{{- if not .ASMArm}} - // Add z = x + y (mod q) func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { {{ $hasCarry := or (not $.NoCarry) (gt $.NbWords 1)}} @@ -434,6 +432,7 @@ func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { } + // Double z = x + x (mod q), aka Lsh 1 func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} { {{- if eq .NbWords 1}} @@ -502,7 +501,6 @@ func (z *{{.ElementName}}) Sub( x, y *{{.ElementName}}) *{{.ElementName}} { return z } -{{- end}} // Neg z = q - x func (z *{{.ElementName}}) Neg( x *{{.ElementName}}) *{{.ElementName}} { diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 78472adfa5..8fce3add34 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -200,30 +200,6 @@ func (z *{{.ElementName}}) Square(x *{{.ElementName}}) *{{.ElementName}} { const OpsARM64 = ` {{if .ASMArm}} -//go:noescape -func add(res,x,y *{{.ElementName}}) - -func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { - add(z,x,y) - return z -} - -//go:noescape -func double(res, x *{{.ElementName}}) - -func (z *{{.ElementName}}) Double(x *{{.ElementName}}) *{{.ElementName}} { - double(z,x) - return z -} - -//go:noescape -func sub(res,x,y *{{.ElementName}}) - -func (z *{{.ElementName}}) Sub(x, y *{{.ElementName}}) *{{.ElementName}} { - sub(z,x,y) - return z -} - //go:noescape func Butterfly(a, b *{{.ElementName}}) From ea8340bbf47c737f5fb534c5a84a02fcf9f817ec Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 21 Oct 2024 15:24:49 -0500 Subject: [PATCH 16/74] feat: refactor code generation to allow space for arm64 --- .gitignore | 1 + ecc/bls12-377/fp/arith.go | 13 - ecc/bls12-377/fp/asm_adx.go | 1 - ecc/bls12-377/fp/asm_noadx.go | 1 - ...{element_ops_amd64.go => element_amd64.go} | 1 - .../fp/element_amd64.s} | 16 +- ecc/bls12-377/fp/element_arm64.go | 80 ++ .../element_arm64.s} | 16 +- ecc/bls12-377/fp/element_ops_amd64.s | 6 - ecc/bls12-377/fp/element_ops_arm64.s | 6 - ecc/bls12-377/fp/element_purego.go | 704 ++++++++++++++++++ ecc/bls12-377/fp/vector.go | 37 - ecc/bls12-377/fp/vector_purego.go | 54 ++ ecc/bls12-377/fr/arith.go | 13 - ecc/bls12-377/fr/asm_adx.go | 1 - ecc/bls12-377/fr/asm_avx.go | 1 - ecc/bls12-377/fr/asm_noadx.go | 1 - ecc/bls12-377/fr/asm_noavx.go | 1 - .../fr/element_amd64.go} | 1 - .../element_amd64.s} | 16 +- ecc/bls12-377/fr/element_arm64.go | 78 ++ .../fr/element_arm64.s} | 16 +- ecc/bls12-377/fr/element_ops_amd64.s | 6 - ecc/bls12-377/fr/element_ops_arm64.s | 6 - ecc/bls12-377/fr/element_ops_purego.go | 260 ------- 
ecc/bls12-377/fr/element_purego.go | 402 ++++++++++ .../{element_ops_amd64.go => vector_amd64.go} | 48 -- ecc/bls12-377/fr/vector_purego.go | 56 ++ ecc/bls12-377/internal/fptower/asm.go | 1 - ecc/bls12-377/internal/fptower/asm_noadx.go | 1 - ecc/bls12-377/internal/fptower/e2_fallback.go | 1 - ecc/bls12-381/fp/arith.go | 13 - ecc/bls12-381/fp/asm_adx.go | 1 - ecc/bls12-381/fp/asm_noadx.go | 1 - .../fp/element_amd64.go} | 1 - ecc/bls12-381/fp/element_amd64.s | 21 + ecc/bls12-381/fp/element_arm64.go | 80 ++ ecc/bls12-381/fp/element_arm64.s | 21 + ecc/bls12-381/fp/element_ops_amd64.s | 6 - ecc/bls12-381/fp/element_ops_arm64.s | 6 - ecc/bls12-381/fp/element_purego.go | 704 ++++++++++++++++++ ecc/bls12-381/fp/vector.go | 37 - ecc/bls12-381/fp/vector_purego.go | 54 ++ ecc/bls12-381/fr/arith.go | 13 - ecc/bls12-381/fr/asm_adx.go | 1 - ecc/bls12-381/fr/asm_avx.go | 1 - ecc/bls12-381/fr/asm_noadx.go | 1 - ecc/bls12-381/fr/asm_noavx.go | 1 - .../fr/element_amd64.go} | 1 - ecc/bls12-381/fr/element_amd64.s | 21 + ecc/bls12-381/fr/element_arm64.go | 78 ++ ecc/bls12-381/fr/element_arm64.s | 21 + ecc/bls12-381/fr/element_ops_amd64.s | 6 - ecc/bls12-381/fr/element_ops_arm64.s | 6 - ecc/bls12-381/fr/element_ops_purego.go | 260 ------- ecc/bls12-381/fr/element_purego.go | 402 ++++++++++ .../fr/vector_amd64.go} | 48 -- ecc/bls12-381/fr/vector_purego.go | 56 ++ ecc/bls12-381/internal/fptower/asm.go | 1 - ecc/bls12-381/internal/fptower/asm_noadx.go | 1 - ecc/bls12-381/internal/fptower/e2_fallback.go | 1 - ecc/bls24-315/fp/arith.go | 13 - ecc/bls24-315/fp/asm_adx.go | 1 - ecc/bls24-315/fp/asm_noadx.go | 1 - .../fp/element_amd64.go} | 1 - ecc/bls24-315/fp/element_amd64.s | 21 + ecc/bls24-315/fp/element_ops_amd64.s | 6 - ...lement_ops_purego.go => element_purego.go} | 20 +- ecc/bls24-315/fp/vector.go | 37 - ecc/bls24-315/fp/vector_purego.go | 54 ++ ecc/bls24-315/fr/arith.go | 13 - ecc/bls24-315/fr/asm_adx.go | 1 - ecc/bls24-315/fr/asm_avx.go | 1 - ecc/bls24-315/fr/asm_noadx.go | 1 - ecc/bls24-315/fr/asm_noavx.go | 1 - ecc/bls24-315/fr/element_amd64.go | 66 ++ ecc/bls24-315/fr/element_amd64.s | 21 + ecc/bls24-315/fr/element_arm64.go | 78 ++ ecc/bls24-315/fr/element_arm64.s | 21 + ecc/bls24-315/fr/element_ops_amd64.s | 6 - ecc/bls24-315/fr/element_ops_arm64.go | 31 - ecc/bls24-315/fr/element_ops_arm64.s | 6 - ecc/bls24-315/fr/element_ops_purego.go | 260 ------- ecc/bls24-315/fr/element_purego.go | 402 ++++++++++ .../fr/vector_amd64.go} | 48 -- ecc/bls24-315/fr/vector_purego.go | 56 ++ ecc/bls24-317/fp/arith.go | 13 - ecc/bls24-317/fp/asm_adx.go | 1 - ecc/bls24-317/fp/asm_noadx.go | 1 - .../fp/element_amd64.go} | 1 - ecc/bls24-317/fp/element_amd64.s | 21 + ecc/bls24-317/fp/element_ops_amd64.s | 6 - ...lement_ops_purego.go => element_purego.go} | 20 +- ecc/bls24-317/fp/vector.go | 37 - ecc/bls24-317/fp/vector_purego.go | 54 ++ ecc/bls24-317/fr/arith.go | 13 - ecc/bls24-317/fr/asm_adx.go | 1 - ecc/bls24-317/fr/asm_avx.go | 1 - ecc/bls24-317/fr/asm_noadx.go | 1 - ecc/bls24-317/fr/asm_noavx.go | 1 - ecc/bls24-317/fr/element_amd64.go | 66 ++ ecc/bls24-317/fr/element_amd64.s | 21 + ecc/bls24-317/fr/element_arm64.go | 78 ++ ecc/bls24-317/fr/element_arm64.s | 21 + ecc/bls24-317/fr/element_ops_amd64.s | 6 - ecc/bls24-317/fr/element_ops_arm64.go | 31 - ecc/bls24-317/fr/element_ops_arm64.s | 6 - ecc/bls24-317/fr/element_ops_purego.go | 260 ------- ecc/bls24-317/fr/element_purego.go | 402 ++++++++++ .../{element_ops_amd64.go => vector_amd64.go} | 48 -- ecc/bls24-317/fr/vector_purego.go | 56 ++ ecc/bn254/fp/arith.go | 13 - 
ecc/bn254/fp/asm_adx.go | 1 - ecc/bn254/fp/asm_avx.go | 1 - ecc/bn254/fp/asm_noadx.go | 1 - ecc/bn254/fp/asm_noavx.go | 1 - ecc/bn254/fp/element_amd64.go | 66 ++ ecc/bn254/fp/element_amd64.s | 21 + ecc/bn254/fp/element_arm64.go | 78 ++ ecc/bn254/fp/element_arm64.s | 21 + ecc/bn254/fp/element_ops_amd64.s | 6 - ecc/bn254/fp/element_ops_arm64.go | 31 - ecc/bn254/fp/element_ops_arm64.s | 6 - ecc/bn254/fp/element_ops_purego.go | 260 ------- .../fp/element_purego.go} | 293 ++++---- .../{element_ops_amd64.go => vector_amd64.go} | 48 -- ecc/bn254/fp/vector_purego.go | 56 ++ ecc/bn254/fr/arith.go | 13 - ecc/bn254/fr/asm_adx.go | 1 - ecc/bn254/fr/asm_avx.go | 1 - ecc/bn254/fr/asm_noadx.go | 1 - ecc/bn254/fr/asm_noavx.go | 1 - ecc/bn254/fr/element_amd64.go | 66 ++ ecc/bn254/fr/element_amd64.s | 21 + ecc/bn254/fr/element_arm64.go | 78 ++ ecc/bn254/fr/element_arm64.s | 21 + ecc/bn254/fr/element_ops_amd64.go | 208 ------ ecc/bn254/fr/element_ops_amd64.s | 6 - ecc/bn254/fr/element_ops_arm64.go | 31 - ecc/bn254/fr/element_ops_arm64.s | 6 - ecc/bn254/fr/element_ops_purego.go | 260 ------- ecc/bn254/fr/element_purego.go | 402 ++++++++++ ecc/bn254/fr/vector_amd64.go | 160 ++++ ecc/bn254/fr/vector_purego.go | 56 ++ ecc/bn254/internal/fptower/asm.go | 1 - ecc/bn254/internal/fptower/asm_noadx.go | 1 - ecc/bn254/internal/fptower/e2_fallback.go | 1 - ecc/bw6-633/fp/arith.go | 13 - ecc/bw6-633/fp/asm_adx.go | 1 - ecc/bw6-633/fp/asm_noadx.go | 1 - ecc/bw6-633/fp/element_amd64.go | 66 ++ ecc/bw6-633/fp/element_amd64.s | 21 + ecc/bw6-633/fp/element_ops_amd64.go | 67 -- ecc/bw6-633/fp/element_ops_amd64.s | 6 - ...lement_ops_purego.go => element_purego.go} | 20 +- ecc/bw6-633/fp/vector.go | 37 - ecc/bw6-633/fp/vector_purego.go | 54 ++ ecc/bw6-633/fr/arith.go | 13 - ecc/bw6-633/fr/asm_adx.go | 1 - ecc/bw6-633/fr/asm_noadx.go | 1 - ecc/bw6-633/fr/element_amd64.go | 66 ++ ecc/bw6-633/fr/element_amd64.s | 21 + ecc/bw6-633/fr/element_ops_amd64.s | 6 - ...lement_ops_purego.go => element_purego.go} | 20 +- ecc/bw6-633/fr/vector.go | 37 - ecc/bw6-633/fr/vector_purego.go | 54 ++ ecc/bw6-761/fp/arith.go | 13 - ecc/bw6-761/fp/asm_adx.go | 1 - ecc/bw6-761/fp/asm_noadx.go | 1 - ecc/bw6-761/fp/element_amd64.go | 66 ++ ecc/bw6-761/fp/element_amd64.s | 21 + ecc/bw6-761/fp/element_ops_amd64.go | 67 -- ecc/bw6-761/fp/element_ops_amd64.s | 6 - ...lement_ops_purego.go => element_purego.go} | 20 +- ecc/bw6-761/fp/vector.go | 37 - ecc/bw6-761/fp/vector_purego.go | 54 ++ ecc/bw6-761/fr/arith.go | 13 - ecc/bw6-761/fr/asm_adx.go | 1 - ecc/bw6-761/fr/asm_noadx.go | 1 - ecc/bw6-761/fr/element_amd64.go | 66 ++ ecc/bw6-761/fr/element_amd64.s | 21 + ecc/bw6-761/fr/element_arm64.go | 80 ++ ecc/bw6-761/fr/element_arm64.s | 21 + ecc/bw6-761/fr/element_ops_amd64.s | 6 - ecc/bw6-761/fr/element_ops_arm64.go | 31 - ecc/bw6-761/fr/element_ops_arm64.s | 6 - ecc/bw6-761/fr/element_purego.go | 704 ++++++++++++++++++ ecc/bw6-761/fr/vector.go | 37 - ecc/bw6-761/fr/vector_purego.go | 54 ++ ...lement_ops_purego.go => element_purego.go} | 17 +- ecc/secp256k1/fp/vector.go | 37 - ecc/secp256k1/fp/vector_purego.go | 54 ++ ...lement_ops_purego.go => element_purego.go} | 17 +- ecc/secp256k1/fr/vector.go | 37 - ecc/secp256k1/fr/vector_purego.go | 54 ++ ecc/stark-curve/fp/arith.go | 13 - ecc/stark-curve/fp/asm_adx.go | 1 - ecc/stark-curve/fp/asm_avx.go | 1 - ecc/stark-curve/fp/asm_noadx.go | 1 - ecc/stark-curve/fp/asm_noavx.go | 1 - ecc/stark-curve/fp/element_amd64.go | 66 ++ ecc/stark-curve/fp/element_amd64.s | 21 + ecc/stark-curve/fp/element_arm64.go | 78 ++ 
ecc/stark-curve/fp/element_arm64.s | 21 + ecc/stark-curve/fp/element_ops_amd64.s | 6 - ecc/stark-curve/fp/element_ops_arm64.go | 31 - ecc/stark-curve/fp/element_ops_arm64.s | 6 - ecc/stark-curve/fp/element_ops_purego.go | 260 ------- .../fp/element_purego.go} | 293 ++++---- .../{element_ops_amd64.go => vector_amd64.go} | 48 -- ecc/stark-curve/fp/vector_purego.go | 56 ++ ecc/stark-curve/fr/arith.go | 13 - ecc/stark-curve/fr/asm_adx.go | 1 - ecc/stark-curve/fr/asm_avx.go | 1 - ecc/stark-curve/fr/asm_noadx.go | 1 - ecc/stark-curve/fr/asm_noavx.go | 1 - ecc/stark-curve/fr/element_amd64.go | 66 ++ ecc/stark-curve/fr/element_amd64.s | 21 + ecc/stark-curve/fr/element_arm64.go | 78 ++ ecc/stark-curve/fr/element_arm64.s | 21 + ecc/stark-curve/fr/element_ops_amd64.go | 208 ------ ecc/stark-curve/fr/element_ops_amd64.s | 6 - ecc/stark-curve/fr/element_ops_arm64.go | 31 - ecc/stark-curve/fr/element_ops_arm64.s | 6 - ecc/stark-curve/fr/element_ops_purego.go | 260 ------- .../fr/element_purego.go} | 293 ++++---- ecc/stark-curve/fr/vector_amd64.go | 160 ++++ ecc/stark-curve/fr/vector_purego.go | 56 ++ field/generator/asm/arm64/build.go | 1 - field/generator/config/field_config.go | 16 +- field/generator/generator.go | 387 ++++------ .../internal/templates/element/arith.go | 16 - .../internal/templates/element/ops_asm.go | 214 ++---- .../internal/templates/element/ops_purego.go | 61 +- .../internal/templates/element/vector.go | 43 -- .../templates/element/vector_ops_asm.go | 147 ++++ .../templates/element/vector_ops_purego.go | 40 + field/goff/cmd/root.go | 2 +- ...lement_ops_purego.go => element_purego.go} | 17 +- field/goldilocks/vector.go | 37 - field/goldilocks/vector_purego.go | 54 ++ go.mod | 2 +- go.sum | 2 + internal/generator/main.go | 8 +- 244 files changed, 8489 insertions(+), 5011 deletions(-) rename ecc/bls12-377/fp/{element_ops_amd64.go => element_amd64.go} (98%) rename ecc/{bls12-381/fp/element_ops_arm64.go => bls12-377/fp/element_amd64.s} (75%) create mode 100644 ecc/bls12-377/fp/element_arm64.go rename ecc/bls12-377/{fr/element_ops_arm64.go => fp/element_arm64.s} (75%) delete mode 100644 ecc/bls12-377/fp/element_ops_amd64.s delete mode 100644 ecc/bls12-377/fp/element_ops_arm64.s create mode 100644 ecc/bls12-377/fp/element_purego.go create mode 100644 ecc/bls12-377/fp/vector_purego.go rename ecc/{bw6-633/fr/element_ops_amd64.go => bls12-377/fr/element_amd64.go} (98%) rename ecc/bls12-377/{fp/element_ops_arm64.go => fr/element_amd64.s} (75%) create mode 100644 ecc/bls12-377/fr/element_arm64.go rename ecc/{bls12-381/fr/element_ops_arm64.go => bls12-377/fr/element_arm64.s} (75%) delete mode 100644 ecc/bls12-377/fr/element_ops_amd64.s delete mode 100644 ecc/bls12-377/fr/element_ops_arm64.s delete mode 100644 ecc/bls12-377/fr/element_ops_purego.go create mode 100644 ecc/bls12-377/fr/element_purego.go rename ecc/bls12-377/fr/{element_ops_amd64.go => vector_amd64.go} (84%) create mode 100644 ecc/bls12-377/fr/vector_purego.go rename ecc/{bls24-315/fp/element_ops_amd64.go => bls12-381/fp/element_amd64.go} (98%) create mode 100644 ecc/bls12-381/fp/element_amd64.s create mode 100644 ecc/bls12-381/fp/element_arm64.go create mode 100644 ecc/bls12-381/fp/element_arm64.s delete mode 100644 ecc/bls12-381/fp/element_ops_amd64.s delete mode 100644 ecc/bls12-381/fp/element_ops_arm64.s create mode 100644 ecc/bls12-381/fp/element_purego.go create mode 100644 ecc/bls12-381/fp/vector_purego.go rename ecc/{bw6-761/fr/element_ops_amd64.go => bls12-381/fr/element_amd64.go} (98%) create mode 100644 
ecc/bls12-381/fr/element_amd64.s create mode 100644 ecc/bls12-381/fr/element_arm64.go create mode 100644 ecc/bls12-381/fr/element_arm64.s delete mode 100644 ecc/bls12-381/fr/element_ops_amd64.s delete mode 100644 ecc/bls12-381/fr/element_ops_arm64.s delete mode 100644 ecc/bls12-381/fr/element_ops_purego.go create mode 100644 ecc/bls12-381/fr/element_purego.go rename ecc/{bls24-315/fr/element_ops_amd64.go => bls12-381/fr/vector_amd64.go} (84%) create mode 100644 ecc/bls12-381/fr/vector_purego.go rename ecc/{bls24-317/fp/element_ops_amd64.go => bls24-315/fp/element_amd64.go} (98%) create mode 100644 ecc/bls24-315/fp/element_amd64.s delete mode 100644 ecc/bls24-315/fp/element_ops_amd64.s rename ecc/bls24-315/fp/{element_ops_purego.go => element_purego.go} (99%) create mode 100644 ecc/bls24-315/fp/vector_purego.go create mode 100644 ecc/bls24-315/fr/element_amd64.go create mode 100644 ecc/bls24-315/fr/element_amd64.s create mode 100644 ecc/bls24-315/fr/element_arm64.go create mode 100644 ecc/bls24-315/fr/element_arm64.s delete mode 100644 ecc/bls24-315/fr/element_ops_amd64.s delete mode 100644 ecc/bls24-315/fr/element_ops_arm64.go delete mode 100644 ecc/bls24-315/fr/element_ops_arm64.s delete mode 100644 ecc/bls24-315/fr/element_ops_purego.go create mode 100644 ecc/bls24-315/fr/element_purego.go rename ecc/{bls12-381/fr/element_ops_amd64.go => bls24-315/fr/vector_amd64.go} (84%) create mode 100644 ecc/bls24-315/fr/vector_purego.go rename ecc/{bls12-381/fp/element_ops_amd64.go => bls24-317/fp/element_amd64.go} (98%) create mode 100644 ecc/bls24-317/fp/element_amd64.s delete mode 100644 ecc/bls24-317/fp/element_ops_amd64.s rename ecc/bls24-317/fp/{element_ops_purego.go => element_purego.go} (99%) create mode 100644 ecc/bls24-317/fp/vector_purego.go create mode 100644 ecc/bls24-317/fr/element_amd64.go create mode 100644 ecc/bls24-317/fr/element_amd64.s create mode 100644 ecc/bls24-317/fr/element_arm64.go create mode 100644 ecc/bls24-317/fr/element_arm64.s delete mode 100644 ecc/bls24-317/fr/element_ops_amd64.s delete mode 100644 ecc/bls24-317/fr/element_ops_arm64.go delete mode 100644 ecc/bls24-317/fr/element_ops_arm64.s delete mode 100644 ecc/bls24-317/fr/element_ops_purego.go create mode 100644 ecc/bls24-317/fr/element_purego.go rename ecc/bls24-317/fr/{element_ops_amd64.go => vector_amd64.go} (84%) create mode 100644 ecc/bls24-317/fr/vector_purego.go create mode 100644 ecc/bn254/fp/element_amd64.go create mode 100644 ecc/bn254/fp/element_amd64.s create mode 100644 ecc/bn254/fp/element_arm64.go create mode 100644 ecc/bn254/fp/element_arm64.s delete mode 100644 ecc/bn254/fp/element_ops_amd64.s delete mode 100644 ecc/bn254/fp/element_ops_arm64.go delete mode 100644 ecc/bn254/fp/element_ops_arm64.s delete mode 100644 ecc/bn254/fp/element_ops_purego.go rename ecc/{bls12-377/fp/element_ops_purego.go => bn254/fp/element_purego.go} (59%) rename ecc/bn254/fp/{element_ops_amd64.go => vector_amd64.go} (84%) create mode 100644 ecc/bn254/fp/vector_purego.go create mode 100644 ecc/bn254/fr/element_amd64.go create mode 100644 ecc/bn254/fr/element_amd64.s create mode 100644 ecc/bn254/fr/element_arm64.go create mode 100644 ecc/bn254/fr/element_arm64.s delete mode 100644 ecc/bn254/fr/element_ops_amd64.go delete mode 100644 ecc/bn254/fr/element_ops_amd64.s delete mode 100644 ecc/bn254/fr/element_ops_arm64.go delete mode 100644 ecc/bn254/fr/element_ops_arm64.s delete mode 100644 ecc/bn254/fr/element_ops_purego.go create mode 100644 ecc/bn254/fr/element_purego.go create mode 100644 ecc/bn254/fr/vector_amd64.go 
create mode 100644 ecc/bn254/fr/vector_purego.go create mode 100644 ecc/bw6-633/fp/element_amd64.go create mode 100644 ecc/bw6-633/fp/element_amd64.s delete mode 100644 ecc/bw6-633/fp/element_ops_amd64.go delete mode 100644 ecc/bw6-633/fp/element_ops_amd64.s rename ecc/bw6-633/fp/{element_ops_purego.go => element_purego.go} (99%) create mode 100644 ecc/bw6-633/fp/vector_purego.go create mode 100644 ecc/bw6-633/fr/element_amd64.go create mode 100644 ecc/bw6-633/fr/element_amd64.s delete mode 100644 ecc/bw6-633/fr/element_ops_amd64.s rename ecc/bw6-633/fr/{element_ops_purego.go => element_purego.go} (99%) create mode 100644 ecc/bw6-633/fr/vector_purego.go create mode 100644 ecc/bw6-761/fp/element_amd64.go create mode 100644 ecc/bw6-761/fp/element_amd64.s delete mode 100644 ecc/bw6-761/fp/element_ops_amd64.go delete mode 100644 ecc/bw6-761/fp/element_ops_amd64.s rename ecc/bw6-761/fp/{element_ops_purego.go => element_purego.go} (99%) create mode 100644 ecc/bw6-761/fp/vector_purego.go create mode 100644 ecc/bw6-761/fr/element_amd64.go create mode 100644 ecc/bw6-761/fr/element_amd64.s create mode 100644 ecc/bw6-761/fr/element_arm64.go create mode 100644 ecc/bw6-761/fr/element_arm64.s delete mode 100644 ecc/bw6-761/fr/element_ops_amd64.s delete mode 100644 ecc/bw6-761/fr/element_ops_arm64.go delete mode 100644 ecc/bw6-761/fr/element_ops_arm64.s create mode 100644 ecc/bw6-761/fr/element_purego.go create mode 100644 ecc/bw6-761/fr/vector_purego.go rename ecc/secp256k1/fp/{element_ops_purego.go => element_purego.go} (99%) create mode 100644 ecc/secp256k1/fp/vector_purego.go rename ecc/secp256k1/fr/{element_ops_purego.go => element_purego.go} (99%) create mode 100644 ecc/secp256k1/fr/vector_purego.go create mode 100644 ecc/stark-curve/fp/element_amd64.go create mode 100644 ecc/stark-curve/fp/element_amd64.s create mode 100644 ecc/stark-curve/fp/element_arm64.go create mode 100644 ecc/stark-curve/fp/element_arm64.s delete mode 100644 ecc/stark-curve/fp/element_ops_amd64.s delete mode 100644 ecc/stark-curve/fp/element_ops_arm64.go delete mode 100644 ecc/stark-curve/fp/element_ops_arm64.s delete mode 100644 ecc/stark-curve/fp/element_ops_purego.go rename ecc/{bls12-381/fp/element_ops_purego.go => stark-curve/fp/element_purego.go} (59%) rename ecc/stark-curve/fp/{element_ops_amd64.go => vector_amd64.go} (84%) create mode 100644 ecc/stark-curve/fp/vector_purego.go create mode 100644 ecc/stark-curve/fr/element_amd64.go create mode 100644 ecc/stark-curve/fr/element_amd64.s create mode 100644 ecc/stark-curve/fr/element_arm64.go create mode 100644 ecc/stark-curve/fr/element_arm64.s delete mode 100644 ecc/stark-curve/fr/element_ops_amd64.go delete mode 100644 ecc/stark-curve/fr/element_ops_amd64.s delete mode 100644 ecc/stark-curve/fr/element_ops_arm64.go delete mode 100644 ecc/stark-curve/fr/element_ops_arm64.s delete mode 100644 ecc/stark-curve/fr/element_ops_purego.go rename ecc/{bw6-761/fr/element_ops_purego.go => stark-curve/fr/element_purego.go} (59%) create mode 100644 ecc/stark-curve/fr/vector_amd64.go create mode 100644 ecc/stark-curve/fr/vector_purego.go create mode 100644 field/generator/internal/templates/element/vector_ops_asm.go create mode 100644 field/generator/internal/templates/element/vector_ops_purego.go rename field/goldilocks/{element_ops_purego.go => element_purego.go} (99%) create mode 100644 field/goldilocks/vector_purego.go diff --git a/.gitignore b/.gitignore index bb95194a5f..4d9d603f74 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ *.so *.dylib *.txt +**/.rbench.lock # 
Test binary, build with `go test -c` *.test diff --git a/ecc/bls12-377/fp/arith.go b/ecc/bls12-377/fp/arith.go index 6f281563b3..66fa667482 100644 --- a/ecc/bls12-377/fp/arith.go +++ b/ecc/bls12-377/fp/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bls12-377/fp/asm_adx.go b/ecc/bls12-377/fp/asm_adx.go index 0481989ec6..f8e29bd1a7 100644 --- a/ecc/bls12-377/fp/asm_adx.go +++ b/ecc/bls12-377/fp/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-377/fp/asm_noadx.go b/ecc/bls12-377/fp/asm_noadx.go index 92f8cc0f42..cb6cfa0f50 100644 --- a/ecc/bls12-377/fp/asm_noadx.go +++ b/ecc/bls12-377/fp/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-377/fp/element_ops_amd64.go b/ecc/bls12-377/fp/element_amd64.go similarity index 98% rename from ecc/bls12-377/fp/element_ops_amd64.go rename to ecc/bls12-377/fp/element_amd64.go index ed2803d717..77a51ee25e 100644 --- a/ecc/bls12-377/fp/element_ops_amd64.go +++ b/ecc/bls12-377/fp/element_amd64.go @@ -1,5 +1,4 @@ //go:build !purego -// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fp/element_ops_arm64.go b/ecc/bls12-377/fp/element_amd64.s similarity index 75% rename from ecc/bls12-381/fp/element_ops_arm64.go rename to ecc/bls12-377/fp/element_amd64.s index 78ae87b96b..872eddf5d6 100644 --- a/ecc/bls12-381/fp/element_ops_arm64.go +++ b/ecc/bls12-377/fp/element_amd64.s @@ -1,5 +1,4 @@ -//go:build !purego -// +build !purego +//go:build !purego // Copyright 2020 ConsenSys Software Inc. // @@ -17,15 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -package fp +// We include the hash to force the Go compiler to recompile: 11124594824487954849 +#include "../../../field/asm/element_6w_amd64.s" -//go:noescape -func Butterfly(a, b *Element) - -//go:noescape -func mul(res, x, y *Element) - -func (z *Element) Mul(x, y *Element) *Element { - mul(z, x, y) - return z -} diff --git a/ecc/bls12-377/fp/element_arm64.go b/ecc/bls12-377/fp/element_arm64.go new file mode 100644 index 0000000000..9c9a211f27 --- /dev/null +++ b/ecc/bls12-377/fp/element_arm64.go @@ -0,0 +1,80 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 1176283927673829444, + 14130787773971430395, + 11354866436980285261, + 15740727779991009548, + 14951814113394531041, + 33013799364667434, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} diff --git a/ecc/bls12-377/fr/element_ops_arm64.go b/ecc/bls12-377/fp/element_arm64.s similarity index 75% rename from ecc/bls12-377/fr/element_ops_arm64.go rename to ecc/bls12-377/fp/element_arm64.s index 6759e524eb..62de3f0be7 100644 --- a/ecc/bls12-377/fr/element_ops_arm64.go +++ b/ecc/bls12-377/fp/element_arm64.s @@ -1,5 +1,4 @@ -//go:build !purego -// +build !purego +//go:build !purego // Copyright 2020 ConsenSys Software Inc. // @@ -17,15 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -package fr +// We include the hash to force the Go compiler to recompile: 4799084555005768587 +#include "../../../field/asm/element_6w_arm64.s" -//go:noescape -func Butterfly(a, b *Element) - -//go:noescape -func mul(res, x, y *Element) - -func (z *Element) Mul(x, y *Element) *Element { - mul(z, x, y) - return z -} diff --git a/ecc/bls12-377/fp/element_ops_amd64.s b/ecc/bls12-377/fp/element_ops_amd64.s deleted file mode 100644 index cabff26f70..0000000000 --- a/ecc/bls12-377/fp/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 11124594824487954849 -#include "../../../field/asm/element_6w_amd64.s" - diff --git a/ecc/bls12-377/fp/element_ops_arm64.s b/ecc/bls12-377/fp/element_ops_arm64.s deleted file mode 100644 index f12adf4dc5..0000000000 --- a/ecc/bls12-377/fp/element_ops_arm64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 4799084555005768587 -#include "../../../field/asm/element_6w_arm64.s" - diff --git a/ecc/bls12-377/fp/element_purego.go b/ecc/bls12-377/fp/element_purego.go new file mode 100644 index 0000000000..be2027de13 --- /dev/null +++ b/ecc/bls12-377/fp/element_purego.go @@ -0,0 +1,704 @@ +//go:build purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 1176283927673829444, + 14130787773971430395, + 11354866436980285261, + 15740727779991009548, + 14951814113394531041, + 33013799364667434, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + 
t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, 
c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + u5, t5 = bits.Mul64(v, x[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, 
t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + 
c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + 
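	// (editor's sketch) The CIOS accumulator is bounded by 2q at this point,
	// which is why the single conditional subtraction below is enough to
	// bring z back into [0, q).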
// if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bls12-377/fp/vector.go b/ecc/bls12-377/fp/vector.go index f1d659e767..64228605a8 100644 --- a/ecc/bls12-377/fp/vector.go +++ b/ecc/bls12-377/fp/vector.go @@ -201,43 +201,6 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/ecc/bls12-377/fp/vector_purego.go b/ecc/bls12-377/fp/vector_purego.go new file mode 100644 index 0000000000..798b669887 --- /dev/null +++ b/ecc/bls12-377/fp/vector_purego.go @@ -0,0 +1,54 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) ScalarMul(a Vector, b *Element) {
+	scalarMulVecGeneric(*vector, a, b)
+}
+
+// Sum computes the sum of all elements in the vector.
+func (vector *Vector) Sum() (res Element) {
+	sumVecGeneric(&res, *vector)
+	return
+}
+
+// InnerProduct computes the inner product of two vectors.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) InnerProduct(other Vector) (res Element) {
+	innerProductVecGeneric(&res, *vector, other)
+	return
+}
+
+// Mul multiplies two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Mul(a, b Vector) {
+	mulVecGeneric(*vector, a, b)
+}
diff --git a/ecc/bls12-377/fr/arith.go b/ecc/bls12-377/fr/arith.go
index 7cfd55da19..83c9fd9ef9 100644
--- a/ecc/bls12-377/fr/arith.go
+++ b/ecc/bls12-377/fr/arith.go
@@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) {
 	hi, _ = bits.Add64(hi, e, carry)
 	return
 }
-func max(a int, b int) int {
-	if a > b {
-		return a
-	}
-	return b
-}
-
-func min(a int, b int) int {
-	if a < b {
-		return a
-	}
-	return b
-}
diff --git a/ecc/bls12-377/fr/asm_adx.go b/ecc/bls12-377/fr/asm_adx.go
index da061913ba..9273ea23ab 100644
--- a/ecc/bls12-377/fr/asm_adx.go
+++ b/ecc/bls12-377/fr/asm_adx.go
@@ -1,5 +1,4 @@
 //go:build !noadx
-// +build !noadx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/bls12-377/fr/asm_avx.go b/ecc/bls12-377/fr/asm_avx.go
index 955f559799..1cc06c6e8d 100644
--- a/ecc/bls12-377/fr/asm_avx.go
+++ b/ecc/bls12-377/fr/asm_avx.go
@@ -1,5 +1,4 @@
 //go:build !noavx
-// +build !noavx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/bls12-377/fr/asm_noadx.go b/ecc/bls12-377/fr/asm_noadx.go
index 7f52ffa197..b784a24247 100644
--- a/ecc/bls12-377/fr/asm_noadx.go
+++ b/ecc/bls12-377/fr/asm_noadx.go
@@ -1,5 +1,4 @@
 //go:build noadx
-// +build noadx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/bls12-377/fr/asm_noavx.go b/ecc/bls12-377/fr/asm_noavx.go
index e5a5b1f2cc..66bfc00772 100644
--- a/ecc/bls12-377/fr/asm_noavx.go
+++ b/ecc/bls12-377/fr/asm_noavx.go
@@ -1,5 +1,4 @@
 //go:build noavx
-// +build noavx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/bw6-633/fr/element_ops_amd64.go b/ecc/bls12-377/fr/element_amd64.go
similarity index 98%
rename from ecc/bw6-633/fr/element_ops_amd64.go
rename to ecc/bls12-377/fr/element_amd64.go
index 83d40c28c1..0ddb905f7b 100644
--- a/ecc/bw6-633/fr/element_ops_amd64.go
+++ b/ecc/bls12-377/fr/element_amd64.go
@@ -1,5 +1,4 @@
 //go:build !purego
-// +build !purego
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/bls12-377/fp/element_ops_arm64.go b/ecc/bls12-377/fr/element_amd64.s
similarity index 75%
rename from ecc/bls12-377/fp/element_ops_arm64.go
rename to ecc/bls12-377/fr/element_amd64.s
index 78ae87b96b..fb00194d7e 100644
--- a/ecc/bls12-377/fp/element_ops_arm64.go
+++ b/ecc/bls12-377/fr/element_amd64.s
@@ -1,5 +1,4 @@
-//go:build !purego
-// +build !purego
+//go:build !purego
 
 // Copyright 2020 ConsenSys Software Inc.
 //
@@ -17,15 +16,6 @@
 
 // Code generated by consensys/gnark-crypto DO NOT EDIT
 
-package fp
+// We include the hash to force the Go compiler to recompile: 9425145785761608449
+#include "../../../field/asm/element_4w_amd64.s"
 
-//go:noescape
-func Butterfly(a, b *Element)
-
-//go:noescape
-func mul(res, x, y *Element)
-
-func (z *Element) Mul(x, y *Element) *Element {
-	mul(z, x, y)
-	return z
-}
diff --git a/ecc/bls12-377/fr/element_arm64.go b/ecc/bls12-377/fr/element_arm64.go
new file mode 100644
index 0000000000..07888f58ba
--- /dev/null
+++ b/ecc/bls12-377/fr/element_arm64.go
@@ -0,0 +1,78 @@
+//go:build !purego
+
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fr
+
+// Butterfly sets
+//
+//	a = a + b (mod q)
+//	b = a - b (mod q)
+//
+//go:noescape
+func Butterfly(a, b *Element)
+
+//go:noescape
+func mul(res, x, y *Element)
+
+// Mul z = x * y (mod q)
+//
+// x and y must be less than q
+func (z *Element) Mul(x, y *Element) *Element {
+	mul(z, x, y)
+	return z
+}
+
+// Square z = x * x (mod q)
+//
+// x must be less than q
+func (z *Element) Square(x *Element) *Element {
+	// see Mul for doc.
+	mul(z, x, x)
+	return z
+}
+
+// MulBy3 x *= 3 (mod q)
+func MulBy3(x *Element) {
+	_x := *x
+	x.Double(x).Add(x, &_x)
+}
+
+// MulBy5 x *= 5 (mod q)
+func MulBy5(x *Element) {
+	_x := *x
+	x.Double(x).Double(x).Add(x, &_x)
+}
+
+// MulBy13 x *= 13 (mod q)
+func MulBy13(x *Element) {
+	var y = Element{
+		18434640649710993230,
+		12067750152132099910,
+		14024878721438555919,
+		347766975729306096,
+	}
+	x.Mul(x, &y)
+}
+
+func fromMont(z *Element) {
+	_fromMontGeneric(z)
+}
+
+func reduce(z *Element) {
+	_reduceGeneric(z)
+}
diff --git a/ecc/bls12-381/fr/element_ops_arm64.go b/ecc/bls12-377/fr/element_arm64.s
similarity index 75%
rename from ecc/bls12-381/fr/element_ops_arm64.go
rename to ecc/bls12-377/fr/element_arm64.s
index 6759e524eb..3cd8aaa667 100644
--- a/ecc/bls12-381/fr/element_ops_arm64.go
+++ b/ecc/bls12-377/fr/element_arm64.s
@@ -1,5 +1,4 @@
-//go:build !purego
-// +build !purego
+//go:build !purego
 
 // Copyright 2020 ConsenSys Software Inc.
 //
@@ -17,15 +16,6 @@
 
 // Code generated by consensys/gnark-crypto DO NOT EDIT
 
-package fr
+// We include the hash to force the Go compiler to recompile: 18027907654287790676
+#include "../../../field/asm/element_4w_arm64.s"
 
-//go:noescape
-func Butterfly(a, b *Element)
-
-//go:noescape
-func mul(res, x, y *Element)
-
-func (z *Element) Mul(x, y *Element) *Element {
-	mul(z, x, y)
-	return z
-}
diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s
deleted file mode 100644
index 6c42136a7a..0000000000
--- a/ecc/bls12-377/fr/element_ops_amd64.s
+++ /dev/null
@@ -1,6 +0,0 @@
-// +build !purego
-
-// Code generated by gnark-crypto/generator. DO NOT EDIT.
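// (editor's note) The "include the hash" comment in these stub .s files is a
// recompilation guard: regenerating the shared field/asm source changes the
// hash, hence the stub's contents, forcing `go build` to re-assemble every
// package that includes it.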
-// We include the hash to force the Go compiler to recompile: 9425145785761608449
-#include "../../../field/asm/element_4w_amd64.s"
-
diff --git a/ecc/bls12-377/fr/element_ops_arm64.s b/ecc/bls12-377/fr/element_ops_arm64.s
deleted file mode 100644
index 6ba54c61aa..0000000000
--- a/ecc/bls12-377/fr/element_ops_arm64.s
+++ /dev/null
@@ -1,6 +0,0 @@
-// +build !purego
-
-// Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 18027907654287790676
-#include "../../../field/asm/element_4w_arm64.s"
-
diff --git a/ecc/bls12-377/fr/element_ops_purego.go b/ecc/bls12-377/fr/element_ops_purego.go
deleted file mode 100644
index 6108a57091..0000000000
--- a/ecc/bls12-377/fr/element_ops_purego.go
+++ /dev/null
@@ -1,260 +0,0 @@
-//go:build !amd64 || purego
-// +build !amd64 purego
-
-// Copyright 2020 ConsenSys Software Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//	http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Code generated by consensys/gnark-crypto DO NOT EDIT
-
-package fr
-
-import "math/bits"
-
-// MulBy3 x *= 3 (mod q)
-func MulBy3(x *Element) {
-	_x := *x
-	x.Double(x).Add(x, &_x)
-}
-
-// MulBy5 x *= 5 (mod q)
-func MulBy5(x *Element) {
-	_x := *x
-	x.Double(x).Double(x).Add(x, &_x)
-}
-
-// MulBy13 x *= 13 (mod q)
-func MulBy13(x *Element) {
-	var y = Element{
-		18434640649710993230,
-		12067750152132099910,
-		14024878721438555919,
-		347766975729306096,
-	}
-	x.Mul(x, &y)
-}
-
-func fromMont(z *Element) {
-	_fromMontGeneric(z)
-}
-
-func reduce(z *Element) {
-	_reduceGeneric(z)
-}
-
-// Add adds two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Add(a, b Vector) {
-	addVecGeneric(*vector, a, b)
-}
-
-// Sub subtracts two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Sub(a, b Vector) {
-	subVecGeneric(*vector, a, b)
-}
-
-// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) ScalarMul(a Vector, b *Element) {
-	scalarMulVecGeneric(*vector, a, b)
-}
-
-// Sum computes the sum of all elements in the vector.
-func (vector *Vector) Sum() (res Element) {
-	sumVecGeneric(&res, *vector)
-	return
-}
-
-// InnerProduct computes the inner product of two vectors.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) InnerProduct(other Vector) (res Element) {
-	innerProductVecGeneric(&res, *vector, other)
-	return
-}
-
-// Mul multiplies two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, x[0]) - u1, t1 = bits.Mul64(v, x[1]) - u2, t2 = bits.Mul64(v, x[2]) - u3, t3 = bits.Mul64(v, x[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = 
bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} diff --git a/ecc/bls12-377/fr/element_purego.go b/ecc/bls12-377/fr/element_purego.go new file mode 100644 index 0000000000..aa0c3e9a75 --- /dev/null +++ b/ecc/bls12-377/fr/element_purego.go @@ -0,0 +1,402 @@ +//go:build purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 18434640649710993230, + 12067750152132099910, + 14024878721438555919, + 347766975729306096, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) 
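	// (editor's sketch) m := qInvNeg*t0 is chosen so that t0 + m*q0 wraps to
	// zero mod 2^64 (qInvNeg ≡ -q^{-1} mod 2^64), which is why each round can
	// discard its low limb after folding in m*q; a self-contained, runnable
	// check of this identity is appended at the end of this excerpt.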
+ u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + 
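	// (editor's sketch) After this shift-and-fold, t0..t3 again hold the
	// running Montgomery total and the round for the next limb of x begins.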
} + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bls12-377/fr/element_ops_amd64.go b/ecc/bls12-377/fr/vector_amd64.go similarity index 84% rename from ecc/bls12-377/fr/element_ops_amd64.go rename to ecc/bls12-377/fr/vector_amd64.go index b653e80069..0164ecb382 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.go +++ b/ecc/bls12-377/fr/vector_amd64.go @@ -1,5 +1,4 @@ //go:build !purego -// +build !purego // Copyright 2020 ConsenSys Software Inc. // @@ -19,32 +18,6 @@ package fr -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -// -//go:noescape -func Butterfly(a, b *Element) - // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -185,24 +158,3 @@ var ( //go:noescape func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for doc. - mul(z, x, x) - return z -} diff --git a/ecc/bls12-377/fr/vector_purego.go b/ecc/bls12-377/fr/vector_purego.go new file mode 100644 index 0000000000..d09c259806 --- /dev/null +++ b/ecc/bls12-377/fr/vector_purego.go @@ -0,0 +1,56 @@ +//go:build purego || arm64 + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bls12-377/internal/fptower/asm.go b/ecc/bls12-377/internal/fptower/asm.go index 49751a9396..03b1160807 100644 --- a/ecc/bls12-377/internal/fptower/asm.go +++ b/ecc/bls12-377/internal/fptower/asm.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 Consensys Software Inc. // diff --git a/ecc/bls12-377/internal/fptower/asm_noadx.go b/ecc/bls12-377/internal/fptower/asm_noadx.go index c6a97081fc..ea7782392c 100644 --- a/ecc/bls12-377/internal/fptower/asm_noadx.go +++ b/ecc/bls12-377/internal/fptower/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 Consensys Software Inc. // diff --git a/ecc/bls12-377/internal/fptower/e2_fallback.go b/ecc/bls12-377/internal/fptower/e2_fallback.go index 6fe47c4111..1b6011564f 100644 --- a/ecc/bls12-377/internal/fptower/e2_fallback.go +++ b/ecc/bls12-377/internal/fptower/e2_fallback.go @@ -1,5 +1,4 @@ //go:build !amd64 -// +build !amd64 // Copyright 2020 Consensys Software Inc. // diff --git a/ecc/bls12-381/fp/arith.go b/ecc/bls12-381/fp/arith.go index 6f281563b3..66fa667482 100644 --- a/ecc/bls12-381/fp/arith.go +++ b/ecc/bls12-381/fp/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bls12-381/fp/asm_adx.go b/ecc/bls12-381/fp/asm_adx.go index 0481989ec6..f8e29bd1a7 100644 --- a/ecc/bls12-381/fp/asm_adx.go +++ b/ecc/bls12-381/fp/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. 
// diff --git a/ecc/bls12-381/fp/asm_noadx.go b/ecc/bls12-381/fp/asm_noadx.go index 92f8cc0f42..cb6cfa0f50 100644 --- a/ecc/bls12-381/fp/asm_noadx.go +++ b/ecc/bls12-381/fp/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fp/element_ops_amd64.go b/ecc/bls12-381/fp/element_amd64.go similarity index 98% rename from ecc/bls24-315/fp/element_ops_amd64.go rename to ecc/bls12-381/fp/element_amd64.go index ed2803d717..77a51ee25e 100644 --- a/ecc/bls24-315/fp/element_ops_amd64.go +++ b/ecc/bls12-381/fp/element_amd64.go @@ -1,5 +1,4 @@ //go:build !purego -// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fp/element_amd64.s b/ecc/bls12-381/fp/element_amd64.s new file mode 100644 index 0000000000..872eddf5d6 --- /dev/null +++ b/ecc/bls12-381/fp/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 11124594824487954849 +#include "../../../field/asm/element_6w_amd64.s" + diff --git a/ecc/bls12-381/fp/element_arm64.go b/ecc/bls12-381/fp/element_arm64.go new file mode 100644 index 0000000000..aed4105795 --- /dev/null +++ b/ecc/bls12-381/fp/element_arm64.go @@ -0,0 +1,80 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. 
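	// (editor's note) Squaring reuses the generic mul assembly with both
	// inputs aliased to x; no dedicated squaring kernel is wired up in this
	// arm64 path.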
+ mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 13438459813099623723, + 14459933216667336738, + 14900020990258308116, + 2941282712809091851, + 13639094935183769893, + 1835248516986607988, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} diff --git a/ecc/bls12-381/fp/element_arm64.s b/ecc/bls12-381/fp/element_arm64.s new file mode 100644 index 0000000000..62de3f0be7 --- /dev/null +++ b/ecc/bls12-381/fp/element_arm64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 4799084555005768587 +#include "../../../field/asm/element_6w_arm64.s" + diff --git a/ecc/bls12-381/fp/element_ops_amd64.s b/ecc/bls12-381/fp/element_ops_amd64.s deleted file mode 100644 index cabff26f70..0000000000 --- a/ecc/bls12-381/fp/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 11124594824487954849 -#include "../../../field/asm/element_6w_amd64.s" - diff --git a/ecc/bls12-381/fp/element_ops_arm64.s b/ecc/bls12-381/fp/element_ops_arm64.s deleted file mode 100644 index f12adf4dc5..0000000000 --- a/ecc/bls12-381/fp/element_ops_arm64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 4799084555005768587 -#include "../../../field/asm/element_6w_arm64.s" - diff --git a/ecc/bls12-381/fp/element_purego.go b/ecc/bls12-381/fp/element_purego.go new file mode 100644 index 0000000000..511c52f1f1 --- /dev/null +++ b/ecc/bls12-381/fp/element_purego.go @@ -0,0 +1,704 @@ +//go:build purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
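// (editor's note) Under the purego build tag this file carries the entire
// portable-Go Element arithmetic for this field; the !purego files above
// route Mul and Square to the shared assembly instead.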
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 13438459813099623723, + 14459933216667336738, + 14900020990258308116, + 2941282712809091851, + 13639094935183769893, + 1835248516986607988, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := 
x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 
= bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + u5, t5 = bits.Mul64(v, x[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + 
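	// (editor's sketch) The two adds that follow fold the leftover high limb
	// u5 and the overflow word c2 back into t4/t5, closing this round.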
t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, 
c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], 
b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bls12-381/fp/vector.go b/ecc/bls12-381/fp/vector.go index f1d659e767..64228605a8 100644 --- a/ecc/bls12-381/fp/vector.go +++ b/ecc/bls12-381/fp/vector.go @@ -201,43 +201,6 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/ecc/bls12-381/fp/vector_purego.go b/ecc/bls12-381/fp/vector_purego.go new file mode 100644 index 0000000000..798b669887 --- /dev/null +++ b/ecc/bls12-381/fp/vector_purego.go @@ -0,0 +1,54 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
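+// The generic path (scalarMulVecGeneric in vector.go) reads the scalar b once
+// and calls Element.Mul for each entry of a.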
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bls12-381/fr/arith.go b/ecc/bls12-381/fr/arith.go index 7cfd55da19..83c9fd9ef9 100644 --- a/ecc/bls12-381/fr/arith.go +++ b/ecc/bls12-381/fr/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bls12-381/fr/asm_adx.go b/ecc/bls12-381/fr/asm_adx.go index da061913ba..9273ea23ab 100644 --- a/ecc/bls12-381/fr/asm_adx.go +++ b/ecc/bls12-381/fr/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fr/asm_avx.go b/ecc/bls12-381/fr/asm_avx.go index 955f559799..1cc06c6e8d 100644 --- a/ecc/bls12-381/fr/asm_avx.go +++ b/ecc/bls12-381/fr/asm_avx.go @@ -1,5 +1,4 @@ //go:build !noavx -// +build !noavx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fr/asm_noadx.go b/ecc/bls12-381/fr/asm_noadx.go index 7f52ffa197..b784a24247 100644 --- a/ecc/bls12-381/fr/asm_noadx.go +++ b/ecc/bls12-381/fr/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fr/asm_noavx.go b/ecc/bls12-381/fr/asm_noavx.go index e5a5b1f2cc..66bfc00772 100644 --- a/ecc/bls12-381/fr/asm_noavx.go +++ b/ecc/bls12-381/fr/asm_noavx.go @@ -1,5 +1,4 @@ //go:build noavx -// +build noavx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fr/element_ops_amd64.go b/ecc/bls12-381/fr/element_amd64.go similarity index 98% rename from ecc/bw6-761/fr/element_ops_amd64.go rename to ecc/bls12-381/fr/element_amd64.go index 83d40c28c1..0ddb905f7b 100644 --- a/ecc/bw6-761/fr/element_ops_amd64.go +++ b/ecc/bls12-381/fr/element_amd64.go @@ -1,5 +1,4 @@ //go:build !purego -// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fr/element_amd64.s b/ecc/bls12-381/fr/element_amd64.s new file mode 100644 index 0000000000..fb00194d7e --- /dev/null +++ b/ecc/bls12-381/fr/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
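+// This stub holds no code of its own: the routines live in a generated
+// assembly file shared by every field with the same limb count (4 × 64-bit
+// words here), and the hash quoted below changes whenever that shared file
+// is regenerated, forcing dependents to recompile.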
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" + diff --git a/ecc/bls12-381/fr/element_arm64.go b/ecc/bls12-381/fr/element_arm64.go new file mode 100644 index 0000000000..e3f8d21ce7 --- /dev/null +++ b/ecc/bls12-381/fr/element_arm64.go @@ -0,0 +1,78 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 120259084260, + 15510977298029211676, + 7326335280343703402, + 5909200893219589146, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} diff --git a/ecc/bls12-381/fr/element_arm64.s b/ecc/bls12-381/fr/element_arm64.s new file mode 100644 index 0000000000..3cd8aaa667 --- /dev/null +++ b/ecc/bls12-381/fr/element_arm64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 18027907654287790676 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s deleted file mode 100644 index 6c42136a7a..0000000000 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 9425145785761608449 -#include "../../../field/asm/element_4w_amd64.s" - diff --git a/ecc/bls12-381/fr/element_ops_arm64.s b/ecc/bls12-381/fr/element_ops_arm64.s deleted file mode 100644 index 6ba54c61aa..0000000000 --- a/ecc/bls12-381/fr/element_ops_arm64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 18027907654287790676 -#include "../../../field/asm/element_4w_arm64.s" - diff --git a/ecc/bls12-381/fr/element_ops_purego.go b/ecc/bls12-381/fr/element_ops_purego.go deleted file mode 100644 index ee62abea36..0000000000 --- a/ecc/bls12-381/fr/element_ops_purego.go +++ /dev/null @@ -1,260 +0,0 @@ -//go:build !amd64 || purego -// +build !amd64 purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -import "math/bits" - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 120259084260, - 15510977298029211676, - 7326335280343703402, - 5909200893219589146, - } - x.Mul(x, &y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} - -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, x[0]) - u1, t1 = bits.Mul64(v, x[1]) - u2, t2 = bits.Mul64(v, x[2]) - u3, t3 = bits.Mul64(v, x[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = 
bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} diff --git a/ecc/bls12-381/fr/element_purego.go b/ecc/bls12-381/fr/element_purego.go new file mode 100644 index 0000000000..be6e50e1ff --- /dev/null +++ b/ecc/bls12-381/fr/element_purego.go @@ -0,0 +1,402 @@ +//go:build purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 120259084260, + 15510977298029211676, + 7326335280343703402, + 5909200893219589146, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) 
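+	// still folding m·q into the running limbs; after this final outer
+	// iteration (v = x[3]) at most one conditional subtraction of q remains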
+ u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + 
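+		// end of the v = x[2] pass; note this pure-Go Square keeps the exact
+		// Mul schedule with y = x and applies no squaring-specific term sharing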
} + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bls24-315/fr/element_ops_amd64.go b/ecc/bls12-381/fr/vector_amd64.go similarity index 84% rename from ecc/bls24-315/fr/element_ops_amd64.go rename to ecc/bls12-381/fr/vector_amd64.go index b653e80069..0164ecb382 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.go +++ b/ecc/bls12-381/fr/vector_amd64.go @@ -1,5 +1,4 @@ //go:build !purego -// +build !purego // Copyright 2020 ConsenSys Software Inc. // @@ -19,32 +18,6 @@ package fr -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -// -//go:noescape -func Butterfly(a, b *Element) - // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -185,24 +158,3 @@ var ( //go:noescape func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for doc. - mul(z, x, x) - return z -} diff --git a/ecc/bls12-381/fr/vector_purego.go b/ecc/bls12-381/fr/vector_purego.go new file mode 100644 index 0000000000..d09c259806 --- /dev/null +++ b/ecc/bls12-381/fr/vector_purego.go @@ -0,0 +1,56 @@ +//go:build purego || arm64 + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bls12-381/internal/fptower/asm.go b/ecc/bls12-381/internal/fptower/asm.go index 49751a9396..03b1160807 100644 --- a/ecc/bls12-381/internal/fptower/asm.go +++ b/ecc/bls12-381/internal/fptower/asm.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 Consensys Software Inc. // diff --git a/ecc/bls12-381/internal/fptower/asm_noadx.go b/ecc/bls12-381/internal/fptower/asm_noadx.go index c6a97081fc..ea7782392c 100644 --- a/ecc/bls12-381/internal/fptower/asm_noadx.go +++ b/ecc/bls12-381/internal/fptower/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 Consensys Software Inc. // diff --git a/ecc/bls12-381/internal/fptower/e2_fallback.go b/ecc/bls12-381/internal/fptower/e2_fallback.go index 6fe47c4111..1b6011564f 100644 --- a/ecc/bls12-381/internal/fptower/e2_fallback.go +++ b/ecc/bls12-381/internal/fptower/e2_fallback.go @@ -1,5 +1,4 @@ //go:build !amd64 -// +build !amd64 // Copyright 2020 Consensys Software Inc. // diff --git a/ecc/bls24-315/fp/arith.go b/ecc/bls24-315/fp/arith.go index 6f281563b3..66fa667482 100644 --- a/ecc/bls24-315/fp/arith.go +++ b/ecc/bls24-315/fp/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bls24-315/fp/asm_adx.go b/ecc/bls24-315/fp/asm_adx.go index 0481989ec6..f8e29bd1a7 100644 --- a/ecc/bls24-315/fp/asm_adx.go +++ b/ecc/bls24-315/fp/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. 
// diff --git a/ecc/bls24-315/fp/asm_noadx.go b/ecc/bls24-315/fp/asm_noadx.go index 92f8cc0f42..cb6cfa0f50 100644 --- a/ecc/bls24-315/fp/asm_noadx.go +++ b/ecc/bls24-315/fp/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fp/element_ops_amd64.go b/ecc/bls24-315/fp/element_amd64.go similarity index 98% rename from ecc/bls24-317/fp/element_ops_amd64.go rename to ecc/bls24-315/fp/element_amd64.go index ed2803d717..77a51ee25e 100644 --- a/ecc/bls24-317/fp/element_ops_amd64.go +++ b/ecc/bls24-315/fp/element_amd64.go @@ -1,5 +1,4 @@ //go:build !purego -// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fp/element_amd64.s b/ecc/bls24-315/fp/element_amd64.s new file mode 100644 index 0000000000..00858d648e --- /dev/null +++ b/ecc/bls24-315/fp/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 18184981773209750009 +#include "../../../field/asm/element_5w_amd64.s" + diff --git a/ecc/bls24-315/fp/element_ops_amd64.s b/ecc/bls24-315/fp/element_ops_amd64.s deleted file mode 100644 index 29314843d7..0000000000 --- a/ecc/bls24-315/fp/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 18184981773209750009 -#include "../../../field/asm/element_5w_amd64.s" - diff --git a/ecc/bls24-315/fp/element_ops_purego.go b/ecc/bls24-315/fp/element_purego.go similarity index 99% rename from ecc/bls24-315/fp/element_ops_purego.go rename to ecc/bls24-315/fp/element_purego.go index 348a99f991..92806a73c7 100644 --- a/ecc/bls24-315/fp/element_ops_purego.go +++ b/ecc/bls24-315/fp/element_purego.go @@ -1,5 +1,4 @@ -//go:build !amd64 || purego -// +build !amd64 purego +//go:build purego || arm64 // Copyright 2020 ConsenSys Software Inc. // @@ -45,15 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// TODO @gbotrel fixme. -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } @@ -541,3 +531,11 @@ func (z *Element) Square(x *Element) *Element { } return z } + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bls24-315/fp/vector.go b/ecc/bls24-315/fp/vector.go index ce61e70ea0..5d7cd31fe6 100644 --- a/ecc/bls24-315/fp/vector.go +++ b/ecc/bls24-315/fp/vector.go @@ -200,43 +200,6 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/ecc/bls24-315/fp/vector_purego.go b/ecc/bls24-315/fp/vector_purego.go new file mode 100644 index 0000000000..798b669887 --- /dev/null +++ b/ecc/bls24-315/fp/vector_purego.go @@ -0,0 +1,54 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bls24-315/fr/arith.go b/ecc/bls24-315/fr/arith.go index 7cfd55da19..83c9fd9ef9 100644 --- a/ecc/bls24-315/fr/arith.go +++ b/ecc/bls24-315/fr/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bls24-315/fr/asm_adx.go b/ecc/bls24-315/fr/asm_adx.go index da061913ba..9273ea23ab 100644 --- a/ecc/bls24-315/fr/asm_adx.go +++ b/ecc/bls24-315/fr/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fr/asm_avx.go b/ecc/bls24-315/fr/asm_avx.go index 955f559799..1cc06c6e8d 100644 --- a/ecc/bls24-315/fr/asm_avx.go +++ b/ecc/bls24-315/fr/asm_avx.go @@ -1,5 +1,4 @@ //go:build !noavx -// +build !noavx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fr/asm_noadx.go b/ecc/bls24-315/fr/asm_noadx.go index 7f52ffa197..b784a24247 100644 --- a/ecc/bls24-315/fr/asm_noadx.go +++ b/ecc/bls24-315/fr/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fr/asm_noavx.go b/ecc/bls24-315/fr/asm_noavx.go index e5a5b1f2cc..66bfc00772 100644 --- a/ecc/bls24-315/fr/asm_noavx.go +++ b/ecc/bls24-315/fr/asm_noavx.go @@ -1,5 +1,4 @@ //go:build noavx -// +build noavx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fr/element_amd64.go b/ecc/bls24-315/fr/element_amd64.go new file mode 100644 index 0000000000..0ddb905f7b --- /dev/null +++ b/ecc/bls24-315/fr/element_amd64.go @@ -0,0 +1,66 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. 
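+	// squaring is delegated to the assembly mul with both operands equal;
+	// no dedicated squaring routine is generated for this limb count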
+ mul(z, x, x) + return z +} diff --git a/ecc/bls24-315/fr/element_amd64.s b/ecc/bls24-315/fr/element_amd64.s new file mode 100644 index 0000000000..fb00194d7e --- /dev/null +++ b/ecc/bls24-315/fr/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" + diff --git a/ecc/bls24-315/fr/element_arm64.go b/ecc/bls24-315/fr/element_arm64.go new file mode 100644 index 0000000000..7b52729f61 --- /dev/null +++ b/ecc/bls24-315/fr/element_arm64.go @@ -0,0 +1,78 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 16427853282514304894, + 880039980351915818, + 13098611234035318378, + 1598436289436461078, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} diff --git a/ecc/bls24-315/fr/element_arm64.s b/ecc/bls24-315/fr/element_arm64.s new file mode 100644 index 0000000000..3cd8aaa667 --- /dev/null +++ b/ecc/bls24-315/fr/element_arm64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 18027907654287790676 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s deleted file mode 100644 index 6c42136a7a..0000000000 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9425145785761608449 -#include "../../../field/asm/element_4w_amd64.s" - diff --git a/ecc/bls24-315/fr/element_ops_arm64.go b/ecc/bls24-315/fr/element_ops_arm64.go deleted file mode 100644 index 6759e524eb..0000000000 --- a/ecc/bls24-315/fr/element_ops_arm64.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !purego -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -//go:noescape -func Butterfly(a, b *Element) - -//go:noescape -func mul(res, x, y *Element) - -func (z *Element) Mul(x, y *Element) *Element { - mul(z, x, y) - return z -} diff --git a/ecc/bls24-315/fr/element_ops_arm64.s b/ecc/bls24-315/fr/element_ops_arm64.s deleted file mode 100644 index 6ba54c61aa..0000000000 --- a/ecc/bls24-315/fr/element_ops_arm64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 18027907654287790676 -#include "../../../field/asm/element_4w_arm64.s" - diff --git a/ecc/bls24-315/fr/element_ops_purego.go b/ecc/bls24-315/fr/element_ops_purego.go deleted file mode 100644 index e35913169e..0000000000 --- a/ecc/bls24-315/fr/element_ops_purego.go +++ /dev/null @@ -1,260 +0,0 @@ -//go:build !amd64 || purego -// +build !amd64 purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -import "math/bits" - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 16427853282514304894, - 880039980351915818, - 13098611234035318378, - 1598436289436461078, - } - x.Mul(x, &y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} - -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, x[0]) - u1, t1 = bits.Mul64(v, x[1]) - u2, t2 = bits.Mul64(v, x[2]) - u3, t3 = bits.Mul64(v, x[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, 
c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} diff --git a/ecc/bls24-315/fr/element_purego.go b/ecc/bls24-315/fr/element_purego.go new file mode 100644 index 0000000000..28d51a0862 --- /dev/null +++ b/ecc/bls24-315/fr/element_purego.go @@ -0,0 +1,402 @@ +//go:build purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
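The purego file that follows implements the small-constant multiplications with Double/Add addition chains instead of assembly. As a reading aid, a minimal sketch of the chain behind the MulBy5 below, using only the Element API that appears in this patch (the helper name mulBy5Sketch is hypothetical, not part of the generated code):

// mulBy5Sketch shows the addition chain used by MulBy5:
// double twice, then add the saved original value back in.
func mulBy5Sketch(x *Element) {
	saved := *x      // keep x0
	x.Double(x)      // x = 2·x0 (mod q)
	x.Double(x)      // x = 4·x0 (mod q)
	x.Add(x, &saved) // x = 4·x0 + x0 = 5·x0 (mod q)
}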
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 16427853282514304894, + 880039980351915818, + 13098611234035318378, + 1598436289436461078, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = 
bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + 
t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bls12-381/fr/element_ops_amd64.go b/ecc/bls24-315/fr/vector_amd64.go similarity index 84% rename from ecc/bls12-381/fr/element_ops_amd64.go rename to ecc/bls24-315/fr/vector_amd64.go index b653e80069..0164ecb382 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.go +++ b/ecc/bls24-315/fr/vector_amd64.go @@ -1,5 +1,4 @@ //go:build !purego -// +build !purego // Copyright 2020 ConsenSys Software Inc. // @@ -19,32 +18,6 @@ package fr -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -// -//go:noescape -func Butterfly(a, b *Element) - // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -185,24 +158,3 @@ var ( //go:noescape func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. 
El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for doc. - mul(z, x, x) - return z -} diff --git a/ecc/bls24-315/fr/vector_purego.go b/ecc/bls24-315/fr/vector_purego.go new file mode 100644 index 0000000000..d09c259806 --- /dev/null +++ b/ecc/bls24-315/fr/vector_purego.go @@ -0,0 +1,56 @@ +//go:build purego || arm64 + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bls24-317/fp/arith.go b/ecc/bls24-317/fp/arith.go index 6f281563b3..66fa667482 100644 --- a/ecc/bls24-317/fp/arith.go +++ b/ecc/bls24-317/fp/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bls24-317/fp/asm_adx.go b/ecc/bls24-317/fp/asm_adx.go index 0481989ec6..f8e29bd1a7 100644 --- a/ecc/bls24-317/fp/asm_adx.go +++ b/ecc/bls24-317/fp/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fp/asm_noadx.go b/ecc/bls24-317/fp/asm_noadx.go index 92f8cc0f42..cb6cfa0f50 100644 --- a/ecc/bls24-317/fp/asm_noadx.go +++ b/ecc/bls24-317/fp/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. 
// diff --git a/ecc/bls12-381/fp/element_ops_amd64.go b/ecc/bls24-317/fp/element_amd64.go similarity index 98% rename from ecc/bls12-381/fp/element_ops_amd64.go rename to ecc/bls24-317/fp/element_amd64.go index ed2803d717..77a51ee25e 100644 --- a/ecc/bls12-381/fp/element_ops_amd64.go +++ b/ecc/bls24-317/fp/element_amd64.go @@ -1,5 +1,4 @@ //go:build !purego -// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fp/element_amd64.s b/ecc/bls24-317/fp/element_amd64.s new file mode 100644 index 0000000000..00858d648e --- /dev/null +++ b/ecc/bls24-317/fp/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 18184981773209750009 +#include "../../../field/asm/element_5w_amd64.s" + diff --git a/ecc/bls24-317/fp/element_ops_amd64.s b/ecc/bls24-317/fp/element_ops_amd64.s deleted file mode 100644 index 29314843d7..0000000000 --- a/ecc/bls24-317/fp/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 18184981773209750009 -#include "../../../field/asm/element_5w_amd64.s" - diff --git a/ecc/bls24-317/fp/element_ops_purego.go b/ecc/bls24-317/fp/element_purego.go similarity index 99% rename from ecc/bls24-317/fp/element_ops_purego.go rename to ecc/bls24-317/fp/element_purego.go index 6663c98bbf..9c855b3475 100644 --- a/ecc/bls24-317/fp/element_ops_purego.go +++ b/ecc/bls24-317/fp/element_purego.go @@ -1,5 +1,4 @@ -//go:build !amd64 || purego -// +build !amd64 purego +//go:build purego || arm64 // Copyright 2020 ConsenSys Software Inc. // @@ -45,15 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// TODO @gbotrel fixme. -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } @@ -541,3 +531,11 @@ func (z *Element) Square(x *Element) *Element { } return z } + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bls24-317/fp/vector.go b/ecc/bls24-317/fp/vector.go index ce61e70ea0..5d7cd31fe6 100644 --- a/ecc/bls24-317/fp/vector.go +++ b/ecc/bls24-317/fp/vector.go @@ -200,43 +200,6 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/ecc/bls24-317/fp/vector_purego.go b/ecc/bls24-317/fp/vector_purego.go new file mode 100644 index 0000000000..798b669887 --- /dev/null +++ b/ecc/bls24-317/fp/vector_purego.go @@ -0,0 +1,54 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bls24-317/fr/arith.go b/ecc/bls24-317/fr/arith.go index 7cfd55da19..83c9fd9ef9 100644 --- a/ecc/bls24-317/fr/arith.go +++ b/ecc/bls24-317/fr/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bls24-317/fr/asm_adx.go b/ecc/bls24-317/fr/asm_adx.go index da061913ba..9273ea23ab 100644 --- a/ecc/bls24-317/fr/asm_adx.go +++ b/ecc/bls24-317/fr/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fr/asm_avx.go b/ecc/bls24-317/fr/asm_avx.go index 955f559799..1cc06c6e8d 100644 --- a/ecc/bls24-317/fr/asm_avx.go +++ b/ecc/bls24-317/fr/asm_avx.go @@ -1,5 +1,4 @@ //go:build !noavx -// +build !noavx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fr/asm_noadx.go b/ecc/bls24-317/fr/asm_noadx.go index 7f52ffa197..b784a24247 100644 --- a/ecc/bls24-317/fr/asm_noadx.go +++ b/ecc/bls24-317/fr/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fr/asm_noavx.go b/ecc/bls24-317/fr/asm_noavx.go index e5a5b1f2cc..66bfc00772 100644 --- a/ecc/bls24-317/fr/asm_noavx.go +++ b/ecc/bls24-317/fr/asm_noavx.go @@ -1,5 +1,4 @@ //go:build noavx -// +build noavx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fr/element_amd64.go b/ecc/bls24-317/fr/element_amd64.go new file mode 100644 index 0000000000..0ddb905f7b --- /dev/null +++ b/ecc/bls24-317/fr/element_amd64.go @@ -0,0 +1,66 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. 
+ mul(z, x, x) + return z +} diff --git a/ecc/bls24-317/fr/element_amd64.s b/ecc/bls24-317/fr/element_amd64.s new file mode 100644 index 0000000000..fb00194d7e --- /dev/null +++ b/ecc/bls24-317/fr/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" + diff --git a/ecc/bls24-317/fr/element_arm64.go b/ecc/bls24-317/fr/element_arm64.go new file mode 100644 index 0000000000..ab8a390bdf --- /dev/null +++ b/ecc/bls24-317/fr/element_arm64.go @@ -0,0 +1,78 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 18446744073709551568, + 10999079689622735090, + 16060824205876888138, + 3752826977836272504, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} diff --git a/ecc/bls24-317/fr/element_arm64.s b/ecc/bls24-317/fr/element_arm64.s new file mode 100644 index 0000000000..3cd8aaa667 --- /dev/null +++ b/ecc/bls24-317/fr/element_arm64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 18027907654287790676 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s deleted file mode 100644 index 6c42136a7a..0000000000 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9425145785761608449 -#include "../../../field/asm/element_4w_amd64.s" - diff --git a/ecc/bls24-317/fr/element_ops_arm64.go b/ecc/bls24-317/fr/element_ops_arm64.go deleted file mode 100644 index 6759e524eb..0000000000 --- a/ecc/bls24-317/fr/element_ops_arm64.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !purego -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -//go:noescape -func Butterfly(a, b *Element) - -//go:noescape -func mul(res, x, y *Element) - -func (z *Element) Mul(x, y *Element) *Element { - mul(z, x, y) - return z -} diff --git a/ecc/bls24-317/fr/element_ops_arm64.s b/ecc/bls24-317/fr/element_ops_arm64.s deleted file mode 100644 index 6ba54c61aa..0000000000 --- a/ecc/bls24-317/fr/element_ops_arm64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 18027907654287790676 -#include "../../../field/asm/element_4w_arm64.s" - diff --git a/ecc/bls24-317/fr/element_ops_purego.go b/ecc/bls24-317/fr/element_ops_purego.go deleted file mode 100644 index 57e48f309b..0000000000 --- a/ecc/bls24-317/fr/element_ops_purego.go +++ /dev/null @@ -1,260 +0,0 @@ -//go:build !amd64 || purego -// +build !amd64 purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
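The new element_arm64.go above declares Butterfly and mul with //go:noescape and no Go body; the implementations come from the shared field/asm/element_4w_arm64.s, pulled in by a one-line .s stub whose hash comment forces the Go toolchain to recompile when the shared assembly changes. For readers unfamiliar with that pairing, a toy illustration of the convention (the names addone.go, addone_arm64.s, and addOne are hypothetical and not part of this patch):

// addone.go
//go:build !purego

package fr

//go:noescape
func addOne(x *uint64) // body provided in assembly

// addone_arm64.s
#include "textflag.h"

// func addOne(x *uint64)
TEXT ·addOne(SB), NOSPLIT, $0-8
	MOVD x+0(FP), R0 // load the pointer argument
	MOVD (R0), R1    // load *x
	ADD  $1, R1, R1  // increment
	MOVD R1, (R0)    // store back
	RET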
- -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -import "math/bits" - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 18446744073709551568, - 10999079689622735090, - 16060824205876888138, - 3752826977836272504, - } - x.Mul(x, &y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} - -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, x[0]) - u1, t1 = bits.Mul64(v, x[1]) - u2, t2 = bits.Mul64(v, x[2]) - u3, t3 = bits.Mul64(v, x[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = 
bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} diff --git a/ecc/bls24-317/fr/element_purego.go b/ecc/bls24-317/fr/element_purego.go new file mode 100644 index 0000000000..af2c85d2cd --- /dev/null +++ b/ecc/bls24-317/fr/element_purego.go @@ -0,0 +1,402 @@ +//go:build purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 18446744073709551568, + 10999079689622735090, + 16060824205876888138, + 3752826977836272504, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) 
+ u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + 
} + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bls24-317/fr/element_ops_amd64.go b/ecc/bls24-317/fr/vector_amd64.go similarity index 84% rename from ecc/bls24-317/fr/element_ops_amd64.go rename to ecc/bls24-317/fr/vector_amd64.go index b653e80069..0164ecb382 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.go +++ b/ecc/bls24-317/fr/vector_amd64.go @@ -1,5 +1,4 @@ //go:build !purego -// +build !purego // Copyright 2020 ConsenSys Software Inc. // @@ -19,32 +18,6 @@ package fr -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -// -//go:noescape -func Butterfly(a, b *Element) - // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -185,24 +158,3 @@ var ( //go:noescape func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for doc. - mul(z, x, x) - return z -} diff --git a/ecc/bls24-317/fr/vector_purego.go b/ecc/bls24-317/fr/vector_purego.go new file mode 100644 index 0000000000..d09c259806 --- /dev/null +++ b/ecc/bls24-317/fr/vector_purego.go @@ -0,0 +1,56 @@ +//go:build purego || arm64 + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bn254/fp/arith.go b/ecc/bn254/fp/arith.go index 6f281563b3..66fa667482 100644 --- a/ecc/bn254/fp/arith.go +++ b/ecc/bn254/fp/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bn254/fp/asm_adx.go b/ecc/bn254/fp/asm_adx.go index 0481989ec6..f8e29bd1a7 100644 --- a/ecc/bn254/fp/asm_adx.go +++ b/ecc/bn254/fp/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fp/asm_avx.go b/ecc/bn254/fp/asm_avx.go index cea035ee84..52fc07a325 100644 --- a/ecc/bn254/fp/asm_avx.go +++ b/ecc/bn254/fp/asm_avx.go @@ -1,5 +1,4 @@ //go:build !noavx -// +build !noavx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fp/asm_noadx.go b/ecc/bn254/fp/asm_noadx.go index 92f8cc0f42..cb6cfa0f50 100644 --- a/ecc/bn254/fp/asm_noadx.go +++ b/ecc/bn254/fp/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fp/asm_noavx.go b/ecc/bn254/fp/asm_noavx.go index 9ca08a375a..12261b1f03 100644 --- a/ecc/bn254/fp/asm_noavx.go +++ b/ecc/bn254/fp/asm_noavx.go @@ -1,5 +1,4 @@ //go:build noavx -// +build noavx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fp/element_amd64.go b/ecc/bn254/fp/element_amd64.go new file mode 100644 index 0000000000..77a51ee25e --- /dev/null +++ b/ecc/bn254/fp/element_amd64.go @@ -0,0 +1,66 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bn254/fp/element_amd64.s b/ecc/bn254/fp/element_amd64.s new file mode 100644 index 0000000000..fb00194d7e --- /dev/null +++ b/ecc/bn254/fp/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" + diff --git a/ecc/bn254/fp/element_arm64.go b/ecc/bn254/fp/element_arm64.go new file mode 100644 index 0000000000..bd05e23e63 --- /dev/null +++ b/ecc/bn254/fp/element_arm64.go @@ -0,0 +1,78 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
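A note on what mul/Mul compute throughout this patch: elements are stored in Montgomery form, so mul returns x·y·R⁻¹ mod q with R = 2^256 for four 64-bit limbs. Algorithm 2 interleaves that reduction limb by limb; the whole-number view of the same reduction (REDC) is sketched below with math/big, under the assumption R = 2^256 and with the hypothetical helper name redc:

import "math/big"

// redc returns T·R⁻¹ mod q for 0 ≤ T < R·q, with R = 2^256.
// It is the big-integer analogue of the per-limb loop in mul/Square.
func redc(T, q *big.Int) *big.Int {
	R := new(big.Int).Lsh(big.NewInt(1), 256)
	m := new(big.Int).Sub(R, new(big.Int).ModInverse(q, R)) // -q⁻¹ mod R, cf. qInvNeg
	m.Mul(m, T).Mod(m, R)                                   // m := T·(-q⁻¹) mod R
	t := new(big.Int).Mul(m, q)
	t.Add(t, T)   // T + m·q ≡ 0 (mod R) by construction of m
	t.Rsh(t, 256) // exact division by R
	if t.Cmp(q) >= 0 {
		t.Sub(t, q) // the trailing "if z ⩾ q → z -= q" step
	}
	return t
}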
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 529957932336199972, + 13952065197595570812, + 769406925088786211, + 2691790815622165739, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} diff --git a/ecc/bn254/fp/element_arm64.s b/ecc/bn254/fp/element_arm64.s new file mode 100644 index 0000000000..3cd8aaa667 --- /dev/null +++ b/ecc/bn254/fp/element_arm64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 18027907654287790676 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s deleted file mode 100644 index 6c42136a7a..0000000000 --- a/ecc/bn254/fp/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9425145785761608449 -#include "../../../field/asm/element_4w_amd64.s" - diff --git a/ecc/bn254/fp/element_ops_arm64.go b/ecc/bn254/fp/element_ops_arm64.go deleted file mode 100644 index 78ae87b96b..0000000000 --- a/ecc/bn254/fp/element_ops_arm64.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !purego -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
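
In `element_arm64.go` above, `MulBy13` multiplies by a hard-coded `Element`; since `Mul` is a Montgomery multiplication, that constant has to be 13·R mod q (13 in Montgomery form) for the result to come out as 13·x. A double-and-add chain is a possible alternative — shown here only as a sketch built from the methods already defined above, not as what the generator emits:

	// mulBy13Chain is a hypothetical alternative to the generated MulBy13:
	// 13 = 8 + 4 + 1, so three doublings and two additions suffice.
	func mulBy13Chain(x *Element) {
		_x := *x       // keep x
		x.Double(x)    // 2x
		x.Double(x)    // 4x
		_4x := *x      // keep 4x
		x.Double(x)    // 8x
		x.Add(x, &_4x) // 12x
		x.Add(x, &_x)  // 13x
	}

The same trade-off explains `MulBy3` and `MulBy5` above: at two or three cheap operations the chain clearly wins, while at 13 the generated code switches to a single multiplication by the precomputed constant.
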
- -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -//go:noescape -func Butterfly(a, b *Element) - -//go:noescape -func mul(res, x, y *Element) - -func (z *Element) Mul(x, y *Element) *Element { - mul(z, x, y) - return z -} diff --git a/ecc/bn254/fp/element_ops_arm64.s b/ecc/bn254/fp/element_ops_arm64.s deleted file mode 100644 index 6ba54c61aa..0000000000 --- a/ecc/bn254/fp/element_ops_arm64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 18027907654287790676 -#include "../../../field/asm/element_4w_arm64.s" - diff --git a/ecc/bn254/fp/element_ops_purego.go b/ecc/bn254/fp/element_ops_purego.go deleted file mode 100644 index 6d41d6578f..0000000000 --- a/ecc/bn254/fp/element_ops_purego.go +++ /dev/null @@ -1,260 +0,0 @@ -//go:build !amd64 || purego -// +build !amd64 purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -import "math/bits" - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 529957932336199972, - 13952065197595570812, - 769406925088786211, - 2691790815622165739, - } - x.Mul(x, &y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} - -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, x[0]) - u1, t1 = bits.Mul64(v, x[1]) - u2, t2 = bits.Mul64(v, x[2]) - u3, t3 = bits.Mul64(v, x[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = 
bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} diff --git a/ecc/bls12-377/fp/element_ops_purego.go b/ecc/bn254/fp/element_purego.go similarity index 59% rename from ecc/bls12-377/fp/element_ops_purego.go rename to ecc/bn254/fp/element_purego.go index dcff0bf7a3..00300a2189 100644 --- a/ecc/bls12-377/fp/element_ops_purego.go +++ b/ecc/bn254/fp/element_purego.go @@ -1,5 +1,4 @@ -//go:build !amd64 || purego -// +build !amd64 purego +//go:build purego // Copyright 2020 ConsenSys Software Inc. // @@ -36,12 +35,10 @@ func MulBy5(x *Element) { // MulBy13 x *= 13 (mod q) func MulBy13(x *Element) { var y = Element{ - 1176283927673829444, - 14130787773971430395, - 11354866436980285261, - 15740727779991009548, - 14951814113394531041, - 33013799364667434, + 529957932336199972, + 13952065197595570812, + 769406925088786211, + 2691790815622165739, } x.Mul(x, &y) } @@ -54,29 +51,27 @@ func reduce(z *Element) { _reduceGeneric(z) } -// Square z = x * x (mod q) +// Mul z = x * y (mod q) // -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { - var t0, t1, t2, t3, t4, t5 uint64 - var u0, u1, u2, u3, u4, u5 uint64 + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 { var c0, c1, c2 uint64 v := x[0] - u0, t0 = bits.Mul64(v, x[0]) - u1, t1 = bits.Mul64(v, x[1]) - u2, t2 = bits.Mul64(v, x[2]) - u3, t3 = bits.Mul64(v, x[3]) - u4, t4 = bits.Mul64(v, x[4]) - u5, t5 = bits.Mul64(v, x[5]) + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, 0, c0) + c2, _ = bits.Add64(u3, 0, c0) m := qInvNeg * t0 @@ -87,46 +82,34 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } { var c0, c1, c2 uint64 v := x[1] - u0, c1 = bits.Mul64(v, x[0]) + u0, c1 = bits.Mul64(v, y[0]) t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) + u1, c1 = bits.Mul64(v, y[1]) t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) + u2, c1 = bits.Mul64(v, y[2]) t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) + u3, c1 = bits.Mul64(v, y[3]) t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, x[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, x[5]) - t5, c0 = bits.Add64(c1, t5, c0) c2, _ = bits.Add64(0, 0, c0) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) + c2, _ = bits.Add64(u3, c2, c0) m := qInvNeg * t0 @@ -137,46 +120,34 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } { var c0, c1, c2 uint64 v := x[2] - u0, c1 = bits.Mul64(v, x[0]) + u0, c1 = bits.Mul64(v, y[0]) t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) + u1, c1 = bits.Mul64(v, y[1]) t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) + u2, c1 = bits.Mul64(v, y[2]) t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) + u3, c1 = bits.Mul64(v, y[3]) t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, x[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, x[5]) - t5, c0 = bits.Add64(c1, t5, c0) c2, _ = bits.Add64(0, 
0, c0) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) + c2, _ = bits.Add64(u3, c2, c0) m := qInvNeg * t0 @@ -187,26 +158,114 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } { var c0, c1, c2 uint64 v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] u0, c1 = bits.Mul64(v, x[0]) t0, c0 = bits.Add64(c1, t0, 0) u1, c1 = bits.Mul64(v, x[1]) @@ -215,18 +274,12 @@ func (z *Element) Square(x *Element) *Element { t2, c0 = bits.Add64(c1, t2, c0) u3, c1 = bits.Mul64(v, x[3]) t3, c0 = bits.Add64(c1, t3, 
c0) - u4, c1 = bits.Mul64(v, x[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, x[5]) - t5, c0 = bits.Add64(c1, t5, c0) c2, _ = bits.Add64(0, 0, c0) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) + c2, _ = bits.Add64(u3, c2, c0) m := qInvNeg * t0 @@ -237,26 +290,20 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } { var c0, c1, c2 uint64 - v := x[4] + v := x[2] u0, c1 = bits.Mul64(v, x[0]) t0, c0 = bits.Add64(c1, t0, 0) u1, c1 = bits.Mul64(v, x[1]) @@ -265,18 +312,12 @@ func (z *Element) Square(x *Element) *Element { t2, c0 = bits.Add64(c1, t2, c0) u3, c1 = bits.Mul64(v, x[3]) t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, x[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, x[5]) - t5, c0 = bits.Add64(c1, t5, c0) c2, _ = bits.Add64(0, 0, c0) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) + c2, _ = bits.Add64(u3, c2, c0) m := qInvNeg * t0 @@ -287,26 +328,20 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } { var c0, c1, c2 uint64 - v := x[5] + v := x[3] u0, c1 = bits.Mul64(v, x[0]) t0, c0 = bits.Add64(c1, t0, 0) u1, c1 = bits.Mul64(v, x[1]) @@ -315,18 +350,12 @@ func (z *Element) Square(x *Element) *Element { t2, c0 = bits.Add64(c1, t2, c0) u3, c1 = bits.Mul64(v, x[3]) t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, x[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, x[5]) - t5, c0 = bits.Add64(c1, t5, c0) c2, _ = bits.Add64(0, 0, c0) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) + c2, _ = bits.Add64(u3, c2, c0) m := qInvNeg * t0 @@ -337,29 +366,21 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - 
u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } z[0] = t0 z[1] = t1 z[2] = t2 z[3] = t3 - z[4] = t4 - z[5] = t5 // if z ⩾ q → z -= q if !z.smallerThanModulus() { @@ -367,9 +388,15 @@ func (z *Element) Square(x *Element) *Element { z[0], b = bits.Sub64(z[0], q0, 0) z[1], b = bits.Sub64(z[1], q1, b) z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) + z[3], _ = bits.Sub64(z[3], q3, b) } return z } + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bn254/fp/element_ops_amd64.go b/ecc/bn254/fp/vector_amd64.go similarity index 84% rename from ecc/bn254/fp/element_ops_amd64.go rename to ecc/bn254/fp/vector_amd64.go index 2ab1a98399..75719deedb 100644 --- a/ecc/bn254/fp/element_ops_amd64.go +++ b/ecc/bn254/fp/vector_amd64.go @@ -1,5 +1,4 @@ //go:build !purego -// +build !purego // Copyright 2020 ConsenSys Software Inc. // @@ -19,32 +18,6 @@ package fp -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -// -//go:noescape -func Butterfly(a, b *Element) - // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -185,24 +158,3 @@ var ( //go:noescape func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for doc. - mul(z, x, x) - return z -} diff --git a/ecc/bn254/fp/vector_purego.go b/ecc/bn254/fp/vector_purego.go new file mode 100644 index 0000000000..c6d37d76f4 --- /dev/null +++ b/ecc/bn254/fp/vector_purego.go @@ -0,0 +1,56 @@ +//go:build purego || arm64 + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
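
Note the build tags on this pair of files: `vector_amd64.go` carries `!purego` (plus the implicit `_amd64` filename constraint) and routes large inputs to the AVX-512 kernels, while this file is compiled for `purego || arm64` — so the arm64 port in this patch accelerates element operations but still falls back to generic loops for vector operations. The `addVecGeneric`-style helpers it delegates to are defined elsewhere in the package; a minimal sketch of what such a fallback presumably looks like (the `Sketch` name and the exact panic message are assumptions):

	// addVecGenericSketch mirrors the expected shape of addVecGeneric:
	// bounds-check once, then one modular addition per element.
	func addVecGenericSketch(res, a, b Vector) {
		if len(a) != len(b) || len(a) != len(res) {
			panic("vector.Add: vectors don't have the same length")
		}
		for i := 0; i < len(a); i++ {
			res[i].Add(&a[i], &b[i])
		}
	}
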
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bn254/fr/arith.go b/ecc/bn254/fr/arith.go index 7cfd55da19..83c9fd9ef9 100644 --- a/ecc/bn254/fr/arith.go +++ b/ecc/bn254/fr/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bn254/fr/asm_adx.go b/ecc/bn254/fr/asm_adx.go index da061913ba..9273ea23ab 100644 --- a/ecc/bn254/fr/asm_adx.go +++ b/ecc/bn254/fr/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fr/asm_avx.go b/ecc/bn254/fr/asm_avx.go index 955f559799..1cc06c6e8d 100644 --- a/ecc/bn254/fr/asm_avx.go +++ b/ecc/bn254/fr/asm_avx.go @@ -1,5 +1,4 @@ //go:build !noavx -// +build !noavx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fr/asm_noadx.go b/ecc/bn254/fr/asm_noadx.go index 7f52ffa197..b784a24247 100644 --- a/ecc/bn254/fr/asm_noadx.go +++ b/ecc/bn254/fr/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fr/asm_noavx.go b/ecc/bn254/fr/asm_noavx.go index e5a5b1f2cc..66bfc00772 100644 --- a/ecc/bn254/fr/asm_noavx.go +++ b/ecc/bn254/fr/asm_noavx.go @@ -1,5 +1,4 @@ //go:build noavx -// +build noavx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fr/element_amd64.go b/ecc/bn254/fr/element_amd64.go new file mode 100644 index 0000000000..0ddb905f7b --- /dev/null +++ b/ecc/bn254/fr/element_amd64.go @@ -0,0 +1,66 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bn254/fr/element_amd64.s b/ecc/bn254/fr/element_amd64.s new file mode 100644 index 0000000000..fb00194d7e --- /dev/null +++ b/ecc/bn254/fr/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" + diff --git a/ecc/bn254/fr/element_arm64.go b/ecc/bn254/fr/element_arm64.go new file mode 100644 index 0000000000..a0ccad8e0e --- /dev/null +++ b/ecc/bn254/fr/element_arm64.go @@ -0,0 +1,78 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. 
+ mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 17868810749992763324, + 5924006745939515753, + 769406925088786241, + 2691790815622165739, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} diff --git a/ecc/bn254/fr/element_arm64.s b/ecc/bn254/fr/element_arm64.s new file mode 100644 index 0000000000..3cd8aaa667 --- /dev/null +++ b/ecc/bn254/fr/element_arm64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 18027907654287790676 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/bn254/fr/element_ops_amd64.go b/ecc/bn254/fr/element_ops_amd64.go deleted file mode 100644 index b653e80069..0000000000 --- a/ecc/bn254/fr/element_ops_amd64.go +++ /dev/null @@ -1,208 +0,0 @@ -//go:build !purego -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -// -//go:noescape -func Butterfly(a, b *Element) - -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Add: vectors don't have the same length") - } - n := uint64(len(a)) - addVec(&(*vector)[0], &a[0], &b[0], n) -} - -//go:noescape -func addVec(res, a, b *Element, n uint64) - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Sub(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Sub: vectors don't have the same length") - } - subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) -} - -//go:noescape -func subVec(res, a, b *Element, n uint64) - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - if len(a) != len(*vector) { - panic("vector.ScalarMul: vectors don't have the same length") - } - const maxN = (1 << 32) - 1 - if !supportAvx512 || uint64(len(a)) >= maxN { - // call scalarMulVecGeneric - scalarMulVecGeneric(*vector, a, b) - return - } - n := uint64(len(a)) - if n == 0 { - return - } - // the code for scalarMul is identical to mulVec; and it expects at least - // 2 elements in the vector to fill the Z registers - var bb [2]Element - bb[0] = *b - bb[1] = *b - const blockSize = 16 - scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) - if n%blockSize != 0 { - // call scalarMulVecGeneric on the rest - start := n - n%blockSize - scalarMulVecGeneric((*vector)[start:], a[start:], b) - } -} - -//go:noescape -func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - n := uint64(len(*vector)) - if n == 0 { - return - } - const minN = 16 * 7 // AVX512 slower than generic for small n - const maxN = (1 << 32) - 1 - if !supportAvx512 || n <= minN || n >= maxN { - // call sumVecGeneric - sumVecGeneric(&res, *vector) - return - } - sumVec(&res, &(*vector)[0], uint64(len(*vector))) - return -} - -//go:noescape -func sumVec(res *Element, a *Element, n uint64) - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - n := uint64(len(*vector)) - if n == 0 { - return - } - if n != uint64(len(other)) { - panic("vector.InnerProduct: vectors don't have the same length") - } - const maxN = (1 << 32) - 1 - if !supportAvx512 || n >= maxN { - // call innerProductVecGeneric - // note; we could split the vector into smaller chunks and call innerProductVec - innerProductVecGeneric(&res, *vector, other) - return - } - innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) - - return -} - -//go:noescape -func innerProdVec(res *uint64, a, b *Element, n uint64) - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Mul(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Mul: vectors don't have the same length") - } - n := uint64(len(a)) - if n == 0 { - return - } - const maxN = (1 << 32) - 1 - if !supportAvx512 || n >= maxN { - // call mulVecGeneric - mulVecGeneric(*vector, a, b) - return - } - - const blockSize = 16 - mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) - if n%blockSize != 0 { - // call mulVecGeneric on the rest - start := n - n%blockSize - mulVecGeneric((*vector)[start:], a[start:], b[start:]) - } - -} - -// Patterns use for transposing the vectors in mulVec -var ( - pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} - pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} - pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} - pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} -) - -//go:noescape -func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for doc. - mul(z, x, x) - return z -} diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s deleted file mode 100644 index 6c42136a7a..0000000000 --- a/ecc/bn254/fr/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9425145785761608449 -#include "../../../field/asm/element_4w_amd64.s" - diff --git a/ecc/bn254/fr/element_ops_arm64.go b/ecc/bn254/fr/element_ops_arm64.go deleted file mode 100644 index 6759e524eb..0000000000 --- a/ecc/bn254/fr/element_ops_arm64.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !purego -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -//go:noescape -func Butterfly(a, b *Element) - -//go:noescape -func mul(res, x, y *Element) - -func (z *Element) Mul(x, y *Element) *Element { - mul(z, x, y) - return z -} diff --git a/ecc/bn254/fr/element_ops_arm64.s b/ecc/bn254/fr/element_ops_arm64.s deleted file mode 100644 index 6ba54c61aa..0000000000 --- a/ecc/bn254/fr/element_ops_arm64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 18027907654287790676 -#include "../../../field/asm/element_4w_arm64.s" - diff --git a/ecc/bn254/fr/element_ops_purego.go b/ecc/bn254/fr/element_ops_purego.go deleted file mode 100644 index 859949859a..0000000000 --- a/ecc/bn254/fr/element_ops_purego.go +++ /dev/null @@ -1,260 +0,0 @@ -//go:build !amd64 || purego -// +build !amd64 purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -import "math/bits" - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 17868810749992763324, - 5924006745939515753, - 769406925088786241, - 2691790815622165739, - } - x.Mul(x, &y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} - -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, x[0]) - u1, t1 = bits.Mul64(v, x[1]) - u2, t2 = bits.Mul64(v, x[2]) - u3, t3 = bits.Mul64(v, x[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = 
bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} diff --git a/ecc/bn254/fr/element_purego.go b/ecc/bn254/fr/element_purego.go new file mode 100644 index 0000000000..78b5d8a64c --- /dev/null +++ b/ecc/bn254/fr/element_purego.go @@ -0,0 +1,402 @@ +//go:build purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 17868810749992763324, + 5924006745939515753, + 769406925088786241, + 2691790815622165739, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) 
+ u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + 
} + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bn254/fr/vector_amd64.go b/ecc/bn254/fr/vector_amd64.go new file mode 100644 index 0000000000..0164ecb382 --- /dev/null +++ b/ecc/bn254/fr/vector_amd64.go @@ -0,0 +1,160 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
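+//
+// Usage sketch (hypothetical values; Vector is a slice of Element):
+//
+//	res := make(Vector, len(a))
+//	var k Element
+//	k.SetUint64(5)
+//	res.ScalarMul(a, &k) // res[i] = 5 * a[i] (mod q)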
+func (vector *Vector) ScalarMul(a Vector, b *Element) {
+	if len(a) != len(*vector) {
+		panic("vector.ScalarMul: vectors don't have the same length")
+	}
+	const maxN = (1 << 32) - 1
+	if !supportAvx512 || uint64(len(a)) >= maxN {
+		// call scalarMulVecGeneric
+		scalarMulVecGeneric(*vector, a, b)
+		return
+	}
+	n := uint64(len(a))
+	if n == 0 {
+		return
+	}
+	// the code for scalarMul is identical to mulVec, and it expects at least
+	// 2 elements in the vector to fill the Z registers
+	var bb [2]Element
+	bb[0] = *b
+	bb[1] = *b
+	const blockSize = 16
+	scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg)
+	if n%blockSize != 0 {
+		// call scalarMulVecGeneric on the rest
+		start := n - n%blockSize
+		scalarMulVecGeneric((*vector)[start:], a[start:], b)
+	}
+}
+
+//go:noescape
+func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64)
+
+// Sum computes the sum of all elements in the vector.
+func (vector *Vector) Sum() (res Element) {
+	n := uint64(len(*vector))
+	if n == 0 {
+		return
+	}
+	const minN = 16 * 7 // AVX512 slower than generic for small n
+	const maxN = (1 << 32) - 1
+	if !supportAvx512 || n <= minN || n >= maxN {
+		// call sumVecGeneric
+		sumVecGeneric(&res, *vector)
+		return
+	}
+	sumVec(&res, &(*vector)[0], uint64(len(*vector)))
+	return
+}
+
+//go:noescape
+func sumVec(res *Element, a *Element, n uint64)
+
+// InnerProduct computes the inner product of two vectors.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) InnerProduct(other Vector) (res Element) {
+	n := uint64(len(*vector))
+	if n == 0 {
+		return
+	}
+	if n != uint64(len(other)) {
+		panic("vector.InnerProduct: vectors don't have the same length")
+	}
+	const maxN = (1 << 32) - 1
+	if !supportAvx512 || n >= maxN {
+		// call innerProductVecGeneric
+		// note: we could split the vector into smaller chunks and call innerProductVec
+		innerProductVecGeneric(&res, *vector, other)
+		return
+	}
+	innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector)))
+
+	return
+}
+
+//go:noescape
+func innerProdVec(res *uint64, a, b *Element, n uint64)
+
+// Mul multiplies two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Mul(a, b Vector) {
+	if len(a) != len(b) || len(a) != len(*vector) {
+		panic("vector.Mul: vectors don't have the same length")
+	}
+	n := uint64(len(a))
+	if n == 0 {
+		return
+	}
+	const maxN = (1 << 32) - 1
+	if !supportAvx512 || n >= maxN {
+		// call mulVecGeneric
+		mulVecGeneric(*vector, a, b)
+		return
+	}
+
+	const blockSize = 16
+	mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg)
+	if n%blockSize != 0 {
+		// call mulVecGeneric on the rest
+		start := n - n%blockSize
+		mulVecGeneric((*vector)[start:], a[start:], b[start:])
+	}
+
+}
+
+// Patterns used for transposing the vectors in mulVec
+var (
+	pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11}
+	pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7}
+	pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11}
+	pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7}
+)
+
+//go:noescape
+func mulVec(res, a, b *Element, n uint64, qInvNeg uint64)
diff --git a/ecc/bn254/fr/vector_purego.go b/ecc/bn254/fr/vector_purego.go
new file mode 100644
index 0000000000..d09c259806
--- /dev/null
+++ b/ecc/bn254/fr/vector_purego.go
@@ -0,0 +1,56 @@
+//go:build purego || arm64
+
+// Copyright 2020 ConsenSys Software Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bn254/internal/fptower/asm.go b/ecc/bn254/internal/fptower/asm.go index 49751a9396..03b1160807 100644 --- a/ecc/bn254/internal/fptower/asm.go +++ b/ecc/bn254/internal/fptower/asm.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 Consensys Software Inc. // diff --git a/ecc/bn254/internal/fptower/asm_noadx.go b/ecc/bn254/internal/fptower/asm_noadx.go index c6a97081fc..ea7782392c 100644 --- a/ecc/bn254/internal/fptower/asm_noadx.go +++ b/ecc/bn254/internal/fptower/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 Consensys Software Inc. // diff --git a/ecc/bn254/internal/fptower/e2_fallback.go b/ecc/bn254/internal/fptower/e2_fallback.go index 6fe47c4111..1b6011564f 100644 --- a/ecc/bn254/internal/fptower/e2_fallback.go +++ b/ecc/bn254/internal/fptower/e2_fallback.go @@ -1,5 +1,4 @@ //go:build !amd64 -// +build !amd64 // Copyright 2020 Consensys Software Inc. // diff --git a/ecc/bw6-633/fp/arith.go b/ecc/bw6-633/fp/arith.go index 6f281563b3..66fa667482 100644 --- a/ecc/bw6-633/fp/arith.go +++ b/ecc/bw6-633/fp/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bw6-633/fp/asm_adx.go b/ecc/bw6-633/fp/asm_adx.go index 0481989ec6..f8e29bd1a7 100644 --- a/ecc/bw6-633/fp/asm_adx.go +++ b/ecc/bw6-633/fp/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. 
// diff --git a/ecc/bw6-633/fp/asm_noadx.go b/ecc/bw6-633/fp/asm_noadx.go index 92f8cc0f42..cb6cfa0f50 100644 --- a/ecc/bw6-633/fp/asm_noadx.go +++ b/ecc/bw6-633/fp/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-633/fp/element_amd64.go b/ecc/bw6-633/fp/element_amd64.go new file mode 100644 index 0000000000..77a51ee25e --- /dev/null +++ b/ecc/bw6-633/fp/element_amd64.go @@ -0,0 +1,66 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bw6-633/fp/element_amd64.s b/ecc/bw6-633/fp/element_amd64.s new file mode 100644 index 0000000000..4820949b69 --- /dev/null +++ b/ecc/bw6-633/fp/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 747913930085520082 +#include "../../../field/asm/element_10w_amd64.s" + diff --git a/ecc/bw6-633/fp/element_ops_amd64.go b/ecc/bw6-633/fp/element_ops_amd64.go deleted file mode 100644 index ed2803d717..0000000000 --- a/ecc/bw6-633/fp/element_ops_amd64.go +++ /dev/null @@ -1,67 +0,0 @@ -//go:build !purego -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -// -//go:noescape -func Butterfly(a, b *Element) - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for doc. - mul(z, x, x) - return z -} diff --git a/ecc/bw6-633/fp/element_ops_amd64.s b/ecc/bw6-633/fp/element_ops_amd64.s deleted file mode 100644 index db6a61c53a..0000000000 --- a/ecc/bw6-633/fp/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 747913930085520082 -#include "../../../field/asm/element_10w_amd64.s" - diff --git a/ecc/bw6-633/fp/element_ops_purego.go b/ecc/bw6-633/fp/element_purego.go similarity index 99% rename from ecc/bw6-633/fp/element_ops_purego.go rename to ecc/bw6-633/fp/element_purego.go index 56c5798d50..6ba677acfe 100644 --- a/ecc/bw6-633/fp/element_ops_purego.go +++ b/ecc/bw6-633/fp/element_purego.go @@ -1,5 +1,4 @@ -//go:build !amd64 || purego -// +build !amd64 purego +//go:build purego || arm64 // Copyright 2020 ConsenSys Software Inc. // @@ -50,15 +49,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// TODO @gbotrel fixme. -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } @@ -1596,3 +1586,11 @@ func (z *Element) Square(x *Element) *Element { } return z } + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bw6-633/fp/vector.go b/ecc/bw6-633/fp/vector.go index 90e2236c7e..9f86d6d8e7 100644 --- a/ecc/bw6-633/fp/vector.go +++ b/ecc/bw6-633/fp/vector.go @@ -205,43 +205,6 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. 
-// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/ecc/bw6-633/fp/vector_purego.go b/ecc/bw6-633/fp/vector_purego.go new file mode 100644 index 0000000000..798b669887 --- /dev/null +++ b/ecc/bw6-633/fp/vector_purego.go @@ -0,0 +1,54 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
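+//
+// For reference, a minimal sketch of the generic loop this build delegates to
+// (the real mulVecGeneric lives in vector.go; its shape is inferred from the
+// addVecGeneric helper shown in that file):
+//
+//	func mulVecGeneric(res, a, b Vector) {
+//		if len(a) != len(b) || len(a) != len(res) {
+//			panic("vector.Mul: vectors don't have the same length")
+//		}
+//		for i := 0; i < len(a); i++ {
+//			res[i].Mul(&a[i], &b[i])
+//		}
+//	}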
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bw6-633/fr/arith.go b/ecc/bw6-633/fr/arith.go index 7cfd55da19..83c9fd9ef9 100644 --- a/ecc/bw6-633/fr/arith.go +++ b/ecc/bw6-633/fr/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bw6-633/fr/asm_adx.go b/ecc/bw6-633/fr/asm_adx.go index da061913ba..9273ea23ab 100644 --- a/ecc/bw6-633/fr/asm_adx.go +++ b/ecc/bw6-633/fr/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-633/fr/asm_noadx.go b/ecc/bw6-633/fr/asm_noadx.go index 7f52ffa197..b784a24247 100644 --- a/ecc/bw6-633/fr/asm_noadx.go +++ b/ecc/bw6-633/fr/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-633/fr/element_amd64.go b/ecc/bw6-633/fr/element_amd64.go new file mode 100644 index 0000000000..0ddb905f7b --- /dev/null +++ b/ecc/bw6-633/fr/element_amd64.go @@ -0,0 +1,66 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bw6-633/fr/element_amd64.s b/ecc/bw6-633/fr/element_amd64.s new file mode 100644 index 0000000000..00858d648e --- /dev/null +++ b/ecc/bw6-633/fr/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 18184981773209750009 +#include "../../../field/asm/element_5w_amd64.s" + diff --git a/ecc/bw6-633/fr/element_ops_amd64.s b/ecc/bw6-633/fr/element_ops_amd64.s deleted file mode 100644 index 29314843d7..0000000000 --- a/ecc/bw6-633/fr/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 18184981773209750009 -#include "../../../field/asm/element_5w_amd64.s" - diff --git a/ecc/bw6-633/fr/element_ops_purego.go b/ecc/bw6-633/fr/element_purego.go similarity index 99% rename from ecc/bw6-633/fr/element_ops_purego.go rename to ecc/bw6-633/fr/element_purego.go index aa5e785c1f..4d38ad7730 100644 --- a/ecc/bw6-633/fr/element_ops_purego.go +++ b/ecc/bw6-633/fr/element_purego.go @@ -1,5 +1,4 @@ -//go:build !amd64 || purego -// +build !amd64 purego +//go:build purego || arm64 // Copyright 2020 ConsenSys Software Inc. // @@ -45,15 +44,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// TODO @gbotrel fixme. -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } @@ -541,3 +531,11 @@ func (z *Element) Square(x *Element) *Element { } return z } + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bw6-633/fr/vector.go b/ecc/bw6-633/fr/vector.go index e3bee5fbd3..146828f9f7 100644 --- a/ecc/bw6-633/fr/vector.go +++ b/ecc/bw6-633/fr/vector.go @@ -200,43 +200,6 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/ecc/bw6-633/fr/vector_purego.go b/ecc/bw6-633/fr/vector_purego.go new file mode 100644 index 0000000000..04662dde33 --- /dev/null +++ b/ecc/bw6-633/fr/vector_purego.go @@ -0,0 +1,54 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bw6-761/fp/arith.go b/ecc/bw6-761/fp/arith.go index 6f281563b3..66fa667482 100644 --- a/ecc/bw6-761/fp/arith.go +++ b/ecc/bw6-761/fp/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bw6-761/fp/asm_adx.go b/ecc/bw6-761/fp/asm_adx.go index 0481989ec6..f8e29bd1a7 100644 --- a/ecc/bw6-761/fp/asm_adx.go +++ b/ecc/bw6-761/fp/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fp/asm_noadx.go b/ecc/bw6-761/fp/asm_noadx.go index 92f8cc0f42..cb6cfa0f50 100644 --- a/ecc/bw6-761/fp/asm_noadx.go +++ b/ecc/bw6-761/fp/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. 
// diff --git a/ecc/bw6-761/fp/element_amd64.go b/ecc/bw6-761/fp/element_amd64.go new file mode 100644 index 0000000000..77a51ee25e --- /dev/null +++ b/ecc/bw6-761/fp/element_amd64.go @@ -0,0 +1,66 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bw6-761/fp/element_amd64.s b/ecc/bw6-761/fp/element_amd64.s new file mode 100644 index 0000000000..8bee44a5e7 --- /dev/null +++ b/ecc/bw6-761/fp/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 13892629867042773109 +#include "../../../field/asm/element_12w_amd64.s" + diff --git a/ecc/bw6-761/fp/element_ops_amd64.go b/ecc/bw6-761/fp/element_ops_amd64.go deleted file mode 100644 index ed2803d717..0000000000 --- a/ecc/bw6-761/fp/element_ops_amd64.go +++ /dev/null @@ -1,67 +0,0 @@ -//go:build !purego -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -// -//go:noescape -func Butterfly(a, b *Element) - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for doc. - mul(z, x, x) - return z -} diff --git a/ecc/bw6-761/fp/element_ops_amd64.s b/ecc/bw6-761/fp/element_ops_amd64.s deleted file mode 100644 index 3c8e045ed6..0000000000 --- a/ecc/bw6-761/fp/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 13892629867042773109 -#include "../../../field/asm/element_12w_amd64.s" - diff --git a/ecc/bw6-761/fp/element_ops_purego.go b/ecc/bw6-761/fp/element_purego.go similarity index 99% rename from ecc/bw6-761/fp/element_ops_purego.go rename to ecc/bw6-761/fp/element_purego.go index 9ff53651dc..21b07566fd 100644 --- a/ecc/bw6-761/fp/element_ops_purego.go +++ b/ecc/bw6-761/fp/element_purego.go @@ -1,5 +1,4 @@ -//go:build !amd64 || purego -// +build !amd64 purego +//go:build purego || arm64 // Copyright 2020 ConsenSys Software Inc. // @@ -52,15 +51,6 @@ func MulBy13(x *Element) { x.Mul(x, &y) } -// TODO @gbotrel fixme. -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } @@ -2186,3 +2176,11 @@ func (z *Element) Square(x *Element) *Element { } return z } + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/bw6-761/fp/vector.go b/ecc/bw6-761/fp/vector.go index 8b91076209..66d5027f64 100644 --- a/ecc/bw6-761/fp/vector.go +++ b/ecc/bw6-761/fp/vector.go @@ -207,43 +207,6 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. 
-// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/ecc/bw6-761/fp/vector_purego.go b/ecc/bw6-761/fp/vector_purego.go new file mode 100644 index 0000000000..798b669887 --- /dev/null +++ b/ecc/bw6-761/fp/vector_purego.go @@ -0,0 +1,54 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
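+//
+// Usage sketch (hypothetical vectors a and b of equal length):
+//
+//	prod := make(Vector, len(a))
+//	prod.Mul(a, b)  // prod[i] = a[i] * b[i] (mod q)
+//	s := prod.Sum() // same value as a.InnerProduct(b)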
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/ecc/bw6-761/fr/arith.go b/ecc/bw6-761/fr/arith.go index 7cfd55da19..83c9fd9ef9 100644 --- a/ecc/bw6-761/fr/arith.go +++ b/ecc/bw6-761/fr/arith.go @@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } -func max(a int, b int) int { - if a > b { - return a - } - return b -} - -func min(a int, b int) int { - if a < b { - return a - } - return b -} diff --git a/ecc/bw6-761/fr/asm_adx.go b/ecc/bw6-761/fr/asm_adx.go index da061913ba..9273ea23ab 100644 --- a/ecc/bw6-761/fr/asm_adx.go +++ b/ecc/bw6-761/fr/asm_adx.go @@ -1,5 +1,4 @@ //go:build !noadx -// +build !noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fr/asm_noadx.go b/ecc/bw6-761/fr/asm_noadx.go index 7f52ffa197..b784a24247 100644 --- a/ecc/bw6-761/fr/asm_noadx.go +++ b/ecc/bw6-761/fr/asm_noadx.go @@ -1,5 +1,4 @@ //go:build noadx -// +build noadx // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fr/element_amd64.go b/ecc/bw6-761/fr/element_amd64.go new file mode 100644 index 0000000000..0ddb905f7b --- /dev/null +++ b/ecc/bw6-761/fr/element_amd64.go @@ -0,0 +1,66 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bw6-761/fr/element_amd64.s b/ecc/bw6-761/fr/element_amd64.s new file mode 100644 index 0000000000..872eddf5d6 --- /dev/null +++ b/ecc/bw6-761/fr/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 11124594824487954849 +#include "../../../field/asm/element_6w_amd64.s" + diff --git a/ecc/bw6-761/fr/element_arm64.go b/ecc/bw6-761/fr/element_arm64.go new file mode 100644 index 0000000000..52262cdba6 --- /dev/null +++ b/ecc/bw6-761/fr/element_arm64.go @@ -0,0 +1,80 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 1176283927673829444, + 14130787773971430395, + 11354866436980285261, + 15740727779991009548, + 14951814113394531041, + 33013799364667434, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} diff --git a/ecc/bw6-761/fr/element_arm64.s b/ecc/bw6-761/fr/element_arm64.s new file mode 100644 index 0000000000..62de3f0be7 --- /dev/null +++ b/ecc/bw6-761/fr/element_arm64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 4799084555005768587 +#include "../../../field/asm/element_6w_arm64.s" + diff --git a/ecc/bw6-761/fr/element_ops_amd64.s b/ecc/bw6-761/fr/element_ops_amd64.s deleted file mode 100644 index cabff26f70..0000000000 --- a/ecc/bw6-761/fr/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. 
-// We include the hash to force the Go compiler to recompile: 11124594824487954849 -#include "../../../field/asm/element_6w_amd64.s" - diff --git a/ecc/bw6-761/fr/element_ops_arm64.go b/ecc/bw6-761/fr/element_ops_arm64.go deleted file mode 100644 index 6759e524eb..0000000000 --- a/ecc/bw6-761/fr/element_ops_arm64.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !purego -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -//go:noescape -func Butterfly(a, b *Element) - -//go:noescape -func mul(res, x, y *Element) - -func (z *Element) Mul(x, y *Element) *Element { - mul(z, x, y) - return z -} diff --git a/ecc/bw6-761/fr/element_ops_arm64.s b/ecc/bw6-761/fr/element_ops_arm64.s deleted file mode 100644 index f12adf4dc5..0000000000 --- a/ecc/bw6-761/fr/element_ops_arm64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 4799084555005768587 -#include "../../../field/asm/element_6w_arm64.s" - diff --git a/ecc/bw6-761/fr/element_purego.go b/ecc/bw6-761/fr/element_purego.go new file mode 100644 index 0000000000..a67b43b678 --- /dev/null +++ b/ecc/bw6-761/fr/element_purego.go @@ -0,0 +1,704 @@ +//go:build purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 1176283927673829444, + 14130787773971430395, + 11354866436980285261, + 15740727779991009548, + 14951814113394531041, + 33013799364667434, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = 
bits.Add64(t3, c1, c0)
+        u4, c1 = bits.Mul64(m, q4)
+        t3, c0 = bits.Add64(t4, c1, c0)
+        u5, c1 = bits.Mul64(m, q5)
+
+        t4, c0 = bits.Add64(0, c1, c0)
+        u5, _ = bits.Add64(u5, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        t3, c0 = bits.Add64(u3, t3, c0)
+        t4, c0 = bits.Add64(u4, t4, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t4, c0 = bits.Add64(t5, t4, 0)
+        t5, _ = bits.Add64(u5, c2, c0)
+
+    }
+    {
+        var c0, c1, c2 uint64
+        v := x[3]
+        u0, c1 = bits.Mul64(v, y[0])
+        t0, c0 = bits.Add64(c1, t0, 0)
+        u1, c1 = bits.Mul64(v, y[1])
+        t1, c0 = bits.Add64(c1, t1, c0)
+        u2, c1 = bits.Mul64(v, y[2])
+        t2, c0 = bits.Add64(c1, t2, c0)
+        u3, c1 = bits.Mul64(v, y[3])
+        t3, c0 = bits.Add64(c1, t3, c0)
+        u4, c1 = bits.Mul64(v, y[4])
+        t4, c0 = bits.Add64(c1, t4, c0)
+        u5, c1 = bits.Mul64(v, y[5])
+        t5, c0 = bits.Add64(c1, t5, c0)
+
+        c2, _ = bits.Add64(0, 0, c0)
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        t4, c0 = bits.Add64(u3, t4, c0)
+        t5, c0 = bits.Add64(u4, t5, c0)
+        c2, _ = bits.Add64(u5, c2, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+        t2, c0 = bits.Add64(t3, c1, c0)
+        u4, c1 = bits.Mul64(m, q4)
+        t3, c0 = bits.Add64(t4, c1, c0)
+        u5, c1 = bits.Mul64(m, q5)
+
+        t4, c0 = bits.Add64(0, c1, c0)
+        u5, _ = bits.Add64(u5, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        t3, c0 = bits.Add64(u3, t3, c0)
+        t4, c0 = bits.Add64(u4, t4, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t4, c0 = bits.Add64(t5, t4, 0)
+        t5, _ = bits.Add64(u5, c2, c0)
+
+    }
+    {
+        var c0, c1, c2 uint64
+        v := x[4]
+        u0, c1 = bits.Mul64(v, y[0])
+        t0, c0 = bits.Add64(c1, t0, 0)
+        u1, c1 = bits.Mul64(v, y[1])
+        t1, c0 = bits.Add64(c1, t1, c0)
+        u2, c1 = bits.Mul64(v, y[2])
+        t2, c0 = bits.Add64(c1, t2, c0)
+        u3, c1 = bits.Mul64(v, y[3])
+        t3, c0 = bits.Add64(c1, t3, c0)
+        u4, c1 = bits.Mul64(v, y[4])
+        t4, c0 = bits.Add64(c1, t4, c0)
+        u5, c1 = bits.Mul64(v, y[5])
+        t5, c0 = bits.Add64(c1, t5, c0)
+
+        c2, _ = bits.Add64(0, 0, c0)
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        t4, c0 = bits.Add64(u3, t4, c0)
+        t5, c0 = bits.Add64(u4, t5, c0)
+        c2, _ = bits.Add64(u5, c2, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+        t2, c0 = bits.Add64(t3, c1, c0)
+        u4, c1 = bits.Mul64(m, q4)
+        t3, c0 = bits.Add64(t4, c1, c0)
+        u5, c1 = bits.Mul64(m, q5)
+
+        t4, c0 = bits.Add64(0, c1, c0)
+        u5, _ = bits.Add64(u5, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        t3, c0 = bits.Add64(u3, t3, c0)
+        t4, c0 = bits.Add64(u4, t4, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t4, c0 = bits.Add64(t5, t4, 0)
+        t5, _ = bits.Add64(u5, c2, c0)
+
+    }
+    {
+        var c0, c1, c2 uint64
+        v := x[5]
+        u0, c1 = bits.Mul64(v, y[0])
+        t0, c0 = bits.Add64(c1, t0, 0)
+        u1, c1 = bits.Mul64(v, y[1])
+        t1, c0 = bits.Add64(c1, t1, c0)
+        u2, c1 = bits.Mul64(v, y[2])
+        t2, c0 = bits.Add64(c1, t2, c0)
+        u3, c1 = bits.Mul64(v, y[3])
+        t3, c0 = bits.Add64(c1, t3, c0)
+        u4, c1 = bits.Mul64(v, y[4])
+        t4, c0 = bits.Add64(c1, t4, c0)
+        u5, c1 = bits.Mul64(v, y[5])
+        t5, c0 = bits.Add64(c1, t5, c0)
+
+        c2, _ = bits.Add64(0, 0, c0)
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        t4, c0 = bits.Add64(u3, t4, c0)
+        t5, c0 = bits.Add64(u4, t5, c0)
+        c2, _ = bits.Add64(u5, c2, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+        t2, c0 = bits.Add64(t3, c1, c0)
+        u4, c1 = bits.Mul64(m, q4)
+        t3, c0 = bits.Add64(t4, c1, c0)
+        u5, c1 = bits.Mul64(m, q5)
+
+        t4, c0 = bits.Add64(0, c1, c0)
+        u5, _ = bits.Add64(u5, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        t3, c0 = bits.Add64(u3, t3, c0)
+        t4, c0 = bits.Add64(u4, t4, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t4, c0 = bits.Add64(t5, t4, 0)
+        t5, _ = bits.Add64(u5, c2, c0)
+
+    }
+    z[0] = t0
+    z[1] = t1
+    z[2] = t2
+    z[3] = t3
+    z[4] = t4
+    z[5] = t5
+
+    // if z ⩾ q → z -= q
+    if !z.smallerThanModulus() {
+        var b uint64
+        z[0], b = bits.Sub64(z[0], q0, 0)
+        z[1], b = bits.Sub64(z[1], q1, b)
+        z[2], b = bits.Sub64(z[2], q2, b)
+        z[3], b = bits.Sub64(z[3], q3, b)
+        z[4], b = bits.Sub64(z[4], q4, b)
+        z[5], _ = bits.Sub64(z[5], q5, b)
+    }
+    return z
+}
+
+// Square z = x * x (mod q)
+//
+// x must be less than q
+func (z *Element) Square(x *Element) *Element {
+    // see Mul for algorithm documentation
+
+    var t0, t1, t2, t3, t4, t5 uint64
+    var u0, u1, u2, u3, u4, u5 uint64
+    {
+        var c0, c1, c2 uint64
+        v := x[0]
+        u0, t0 = bits.Mul64(v, x[0])
+        u1, t1 = bits.Mul64(v, x[1])
+        u2, t2 = bits.Mul64(v, x[2])
+        u3, t3 = bits.Mul64(v, x[3])
+        u4, t4 = bits.Mul64(v, x[4])
+        u5, t5 = bits.Mul64(v, x[5])
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        t4, c0 = bits.Add64(u3, t4, c0)
+        t5, c0 = bits.Add64(u4, t5, c0)
+        c2, _ = bits.Add64(u5, 0, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+        t2, c0 = bits.Add64(t3, c1, c0)
+        u4, c1 = bits.Mul64(m, q4)
+        t3, c0 = bits.Add64(t4, c1, c0)
+        u5, c1 = bits.Mul64(m, q5)
+
+        t4, c0 = bits.Add64(0, c1, c0)
+        u5, _ = bits.Add64(u5, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        t3, c0 = bits.Add64(u3, t3, c0)
+        t4, c0 = bits.Add64(u4, t4, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t4, c0 = bits.Add64(t5, t4, 0)
+        t5, _ = bits.Add64(u5, c2, c0)
+
+    }
+    {
+        var c0, c1, c2 uint64
+        v := x[1]
+        u0, c1 = bits.Mul64(v, x[0])
+        t0, c0 = bits.Add64(c1, t0, 0)
+        u1, c1 = bits.Mul64(v, x[1])
+        t1, c0 = bits.Add64(c1, t1, c0)
+        u2, c1 = bits.Mul64(v, x[2])
+        t2, c0 = bits.Add64(c1, t2, c0)
+        u3, c1 = bits.Mul64(v, x[3])
+        t3, c0 = bits.Add64(c1, t3, c0)
+        u4, c1 = bits.Mul64(v, x[4])
+        t4, c0 = bits.Add64(c1, t4, c0)
+        u5, c1 = bits.Mul64(v, x[5])
+        t5, c0 = bits.Add64(c1, t5, c0)
+
+        c2, _ = bits.Add64(0, 0, c0)
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        t4, c0 = bits.Add64(u3, t4, c0)
+        t5, c0 = bits.Add64(u4, t5, c0)
+        c2, _ = bits.Add64(u5, c2, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+        t2, c0 = bits.Add64(t3, c1, c0)
+        u4, c1 = bits.Mul64(m, q4)
+        t3, c0 = bits.Add64(t4, c1, c0)
+        u5, c1 = bits.Mul64(m, q5)
+
+        t4, c0 = bits.Add64(0, c1, c0)
+        u5, _ = bits.Add64(u5, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        t3, c0 = bits.Add64(u3, t3, c0)
+        t4, c0 = bits.Add64(u4, t4, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t4, c0 = bits.Add64(t5, t4, 0)
+        t5, _ = bits.Add64(u5, c2, c0)
+
+    }
+    {
+        var c0, c1, c2 uint64
+        v := x[2]
+        u0, c1 = bits.Mul64(v, x[0])
+        t0, c0 = bits.Add64(c1, t0, 0)
+        u1, c1 = bits.Mul64(v, x[1])
+        t1, c0 = bits.Add64(c1, t1, c0)
+        u2, c1 = bits.Mul64(v, x[2])
+        t2, c0 = bits.Add64(c1, t2, c0)
+        u3, c1 = bits.Mul64(v, x[3])
+        t3, c0 = bits.Add64(c1, t3, c0)
+        u4, c1 = bits.Mul64(v, x[4])
+        t4, c0 = bits.Add64(c1, t4, c0)
+        u5, c1 = bits.Mul64(v, x[5])
+        t5, c0 = bits.Add64(c1, t5, c0)
+
+        c2, _ = bits.Add64(0, 0, c0)
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        t4, c0 = bits.Add64(u3, t4, c0)
+        t5, c0 = bits.Add64(u4, t5, c0)
+        c2, _ = bits.Add64(u5, c2, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+        t2, c0 = bits.Add64(t3, c1, c0)
+        u4, c1 = bits.Mul64(m, q4)
+        t3, c0 = bits.Add64(t4, c1, c0)
+        u5, c1 = bits.Mul64(m, q5)
+
+        t4, c0 = bits.Add64(0, c1, c0)
+        u5, _ = bits.Add64(u5, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        t3, c0 = bits.Add64(u3, t3, c0)
+        t4, c0 = bits.Add64(u4, t4, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t4, c0 = bits.Add64(t5, t4, 0)
+        t5, _ = bits.Add64(u5, c2, c0)
+
+    }
+    {
+        var c0, c1, c2 uint64
+        v := x[3]
+        u0, c1 = bits.Mul64(v, x[0])
+        t0, c0 = bits.Add64(c1, t0, 0)
+        u1, c1 = bits.Mul64(v, x[1])
+        t1, c0 = bits.Add64(c1, t1, c0)
+        u2, c1 = bits.Mul64(v, x[2])
+        t2, c0 = bits.Add64(c1, t2, c0)
+        u3, c1 = bits.Mul64(v, x[3])
+        t3, c0 = bits.Add64(c1, t3, c0)
+        u4, c1 = bits.Mul64(v, x[4])
+        t4, c0 = bits.Add64(c1, t4, c0)
+        u5, c1 = bits.Mul64(v, x[5])
+        t5, c0 = bits.Add64(c1, t5, c0)
+
+        c2, _ = bits.Add64(0, 0, c0)
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        t4, c0 = bits.Add64(u3, t4, c0)
+        t5, c0 = bits.Add64(u4, t5, c0)
+        c2, _ = bits.Add64(u5, c2, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+        t2, c0 = bits.Add64(t3, c1, c0)
+        u4, c1 = bits.Mul64(m, q4)
+        t3, c0 = bits.Add64(t4, c1, c0)
+        u5, c1 = bits.Mul64(m, q5)
+
+        t4, c0 = bits.Add64(0, c1, c0)
+        u5, _ = bits.Add64(u5, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        t3, c0 = bits.Add64(u3, t3, c0)
+        t4, c0 = bits.Add64(u4, t4, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t4, c0 = bits.Add64(t5, t4, 0)
+        t5, _ = bits.Add64(u5, c2, c0)
+
+    }
+    {
+        var c0, c1, c2 uint64
+        v := x[4]
+        u0, c1 = bits.Mul64(v, x[0])
+        t0, c0 = bits.Add64(c1, t0, 0)
+        u1, c1 = bits.Mul64(v, x[1])
+        t1, c0 = bits.Add64(c1, t1, c0)
+        u2, c1 = bits.Mul64(v, x[2])
+        t2, c0 = bits.Add64(c1, t2, c0)
+        u3, c1 = bits.Mul64(v, x[3])
+        t3, c0 = bits.Add64(c1, t3, c0)
+        u4, c1 = bits.Mul64(v, x[4])
+        t4, c0 = bits.Add64(c1, t4, c0)
+        u5, c1 = bits.Mul64(v, x[5])
+        t5, c0 = bits.Add64(c1, t5, c0)
+
+        c2, _ = bits.Add64(0, 0, c0)
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        t4, c0 = bits.Add64(u3, t4, c0)
+        t5, c0 = bits.Add64(u4, t5, c0)
+        c2, _ = bits.Add64(u5, c2, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+        t2, c0 = bits.Add64(t3, c1, c0)
+        u4, c1 = bits.Mul64(m, q4)
+        t3, c0 = bits.Add64(t4, c1, c0)
+        u5, c1 = bits.Mul64(m, q5)
+
+        t4, c0 = bits.Add64(0, c1, c0)
+        u5, _ = bits.Add64(u5, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        t3, c0 = bits.Add64(u3, t3, c0)
+        t4, c0 = bits.Add64(u4, t4, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t4, c0 = bits.Add64(t5, t4, 0)
+        t5, _ = bits.Add64(u5, c2, c0)
+
+    }
+    {
+        var c0, c1, c2 uint64
+        v := x[5]
+        u0, c1 = bits.Mul64(v, x[0])
+        t0, c0 = bits.Add64(c1, t0, 0)
+        u1, c1 = bits.Mul64(v, x[1])
+        t1, c0 = bits.Add64(c1, t1, c0)
+        u2, c1 = bits.Mul64(v, x[2])
+        t2, c0 = bits.Add64(c1, t2, c0)
+        u3, c1 = bits.Mul64(v, x[3])
+        t3, c0 = bits.Add64(c1, t3, c0)
+        u4, c1 = bits.Mul64(v, x[4])
+        t4, c0 = bits.Add64(c1, t4, c0)
+        u5, c1 = bits.Mul64(v, x[5])
+        t5, c0 = bits.Add64(c1, t5, c0)
+
+        c2, _ = bits.Add64(0, 0, c0)
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        t4, c0 = bits.Add64(u3, t4, c0)
+        t5, c0 = bits.Add64(u4, t5, c0)
+        c2, _ = bits.Add64(u5, c2, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+        t2, c0 = bits.Add64(t3, c1, c0)
+        u4, c1 = bits.Mul64(m, q4)
+        t3, c0 = bits.Add64(t4, c1, c0)
+        u5, c1 = bits.Mul64(m, q5)
+
+        t4, c0 = bits.Add64(0, c1, c0)
+        u5, _ = bits.Add64(u5, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        t3, c0 = bits.Add64(u3, t3, c0)
+        t4, c0 = bits.Add64(u4, t4, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t4, c0 = bits.Add64(t5, t4, 0)
+        t5, _ = bits.Add64(u5, c2, c0)
+
+    }
+    z[0] = t0
+    z[1] = t1
+    z[2] = t2
+    z[3] = t3
+    z[4] = t4
+    z[5] = t5
+
+    // if z ⩾ q → z -= q
+    if !z.smallerThanModulus() {
+        var b uint64
+        z[0], b = bits.Sub64(z[0], q0, 0)
+        z[1], b = bits.Sub64(z[1], q1, b)
+        z[2], b = bits.Sub64(z[2], q2, b)
+        z[3], b = bits.Sub64(z[3], q3, b)
+        z[4], b = bits.Sub64(z[4], q4, b)
+        z[5], _ = bits.Sub64(z[5], q5, b)
+    }
+    return z
+}
+
+// Butterfly sets
+//
+// a = a + b (mod q)
+// b = a - b (mod q)
+func Butterfly(a, b *Element) {
+    _butterflyGeneric(a, b)
+}
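Editor's note: the unrolled limb arithmetic above is machine-generated. For readers tracing it, here is a compact, loop-based Go sketch of the same CIOS (coarsely integrated operand scanning) Montgomery multiplication, written against a hypothetical 4-limb field; q, qInvNeg, and the final conditional subtraction play the same roles as in the generated code. This is an illustration only, not part of the patch.

// assumes: import "math/bits"

// madd returns the 128-bit value a*b + c + d as (hi, lo).
// It cannot overflow: (2^64-1)^2 + 2*(2^64-1) < 2^128.
func madd(a, b, c, d uint64) (hi, lo uint64) {
    hi, lo = bits.Mul64(a, b)
    var carry uint64
    lo, carry = bits.Add64(lo, c, 0)
    hi, _ = bits.Add64(hi, 0, carry)
    lo, carry = bits.Add64(lo, d, 0)
    hi, _ = bits.Add64(hi, 0, carry)
    return
}

// ciosMul sets z = x*y*2^-256 mod q, operands in Montgomery form.
// The generated code above fully unrolls these loops and specializes
// the limbs of q and the constant qInvNeg.
func ciosMul(z, x, y, q *[4]uint64, qInvNeg uint64) {
    var t [6]uint64 // 4 result limbs + 2 carry words
    for i := 0; i < 4; i++ {
        // t <- t + x[i]*y
        var c uint64
        for j := 0; j < 4; j++ {
            c, t[j] = madd(x[i], y[j], t[j], c)
        }
        t[4], c = bits.Add64(t[4], c, 0)
        t[5] = c
        // t <- (t + m*q) / 2^64, with m chosen so the low limb cancels
        m := t[0] * qInvNeg
        c, _ = madd(m, q[0], t[0], 0)
        for j := 1; j < 4; j++ {
            c, t[j-1] = madd(m, q[j], t[j], c)
        }
        t[3], c = bits.Add64(t[4], c, 0)
        t[4] = t[5] + c
    }
    // final conditional subtraction, cf. the smallerThanModulus check above
    var b uint64
    var r [4]uint64
    for j := 0; j < 4; j++ {
        r[j], b = bits.Sub64(t[j], q[j], b)
    }
    if t[4] != 0 || b == 0 {
        copy(z[:], r[:])
    } else {
        copy(z[:], t[:4])
    }
}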
diff --git a/ecc/bw6-761/fr/vector.go b/ecc/bw6-761/fr/vector.go
index af400c4e47..47b20efbed 100644
--- a/ecc/bw6-761/fr/vector.go
+++ b/ecc/bw6-761/fr/vector.go
@@ -201,43 +201,6 @@ func (vector Vector) Swap(i, j int) {
     vector[i], vector[j] = vector[j], vector[i]
 }
 
-// Add adds two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Add(a, b Vector) {
-    addVecGeneric(*vector, a, b)
-}
-
-// Sub subtracts two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Sub(a, b Vector) {
-    subVecGeneric(*vector, a, b)
-}
-
-// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) ScalarMul(a Vector, b *Element) {
-    scalarMulVecGeneric(*vector, a, b)
-}
-
-// Sum computes the sum of all elements in the vector.
-func (vector *Vector) Sum() (res Element) {
-    sumVecGeneric(&res, *vector)
-    return
-}
-
-// InnerProduct computes the inner product of two vectors.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) InnerProduct(other Vector) (res Element) {
-    innerProductVecGeneric(&res, *vector, other)
-    return
-}
-
-// Mul multiplies two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Mul(a, b Vector) {
-    mulVecGeneric(*vector, a, b)
-}
-
 func addVecGeneric(res, a, b Vector) {
     if len(a) != len(b) || len(a) != len(res) {
         panic("vector.Add: vectors don't have the same length")
diff --git a/ecc/bw6-761/fr/vector_purego.go b/ecc/bw6-761/fr/vector_purego.go
new file mode 100644
index 0000000000..04662dde33
--- /dev/null
+++ b/ecc/bw6-761/fr/vector_purego.go
@@ -0,0 +1,54 @@
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fr
+
+// Add adds two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Add(a, b Vector) {
+    addVecGeneric(*vector, a, b)
+}
+
+// Sub subtracts two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Sub(a, b Vector) {
+    subVecGeneric(*vector, a, b)
+}
+
+// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) ScalarMul(a Vector, b *Element) {
+    scalarMulVecGeneric(*vector, a, b)
+}
+
+// Sum computes the sum of all elements in the vector.
+func (vector *Vector) Sum() (res Element) {
+    sumVecGeneric(&res, *vector)
+    return
+}
+
+// InnerProduct computes the inner product of two vectors.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) InnerProduct(other Vector) (res Element) {
+    innerProductVecGeneric(&res, *vector, other)
+    return
+}
+
+// Mul multiplies two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Mul(a, b Vector) {
+    mulVecGeneric(*vector, a, b)
+}
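Editor's note: the new vector_purego.go routes every vector method through a *VecGeneric helper that stays in vector.go. The helpers' shape is visible in the hunk's context lines: a length check followed by an element loop. A sketch of addVecGeneric consistent with those context lines (the loop body is an assumption; only the panic is shown in the patch):

func addVecGeneric(res, a, b Vector) {
    if len(a) != len(b) || len(a) != len(res) {
        panic("vector.Add: vectors don't have the same length")
    }
    for i := 0; i < len(a); i++ {
        res[i].Add(&a[i], &b[i]) // element-wise modular addition
    }
}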
diff --git a/ecc/secp256k1/fp/element_ops_purego.go b/ecc/secp256k1/fp/element_purego.go
similarity index 99%
rename from ecc/secp256k1/fp/element_ops_purego.go
rename to ecc/secp256k1/fp/element_purego.go
index 9059e82e31..3be147f51a 100644
--- a/ecc/secp256k1/fp/element_ops_purego.go
+++ b/ecc/secp256k1/fp/element_purego.go
@@ -41,15 +41,6 @@ func MulBy13(x *Element) {
     x.Mul(x, &y)
 }
 
-// TODO @gbotrel fixme.
-// Butterfly sets
-//
-// a = a + b (mod q)
-// b = a - b (mod q)
-func Butterfly(a, b *Element) {
-    _butterflyGeneric(a, b)
-}
-
 func fromMont(z *Element) {
     _fromMontGeneric(z)
 }
@@ -305,3 +296,11 @@ func (z *Element) Square(x *Element) *Element {
     }
     return z
 }
+
+// Butterfly sets
+//
+// a = a + b (mod q)
+// b = a - b (mod q)
+func Butterfly(a, b *Element) {
+    _butterflyGeneric(a, b)
+}
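Editor's note: Butterfly delegates to _butterflyGeneric, which is not shown in any hunk of this patch. A plausible reference implementation, inferred purely from the documented contract a = a + b, b = a - b (an assumption, not library code):

// _butterflyGeneric applies the radix-2 butterfly in place:
// (a, b) <- (a+b, a-b) mod q. Sketch only; the library's version
// may differ in naming and internals.
func _butterflyGeneric(a, b *Element) {
    t := *a     // save a before it is overwritten
    a.Add(a, b) // a = a + b (mod q)
    b.Sub(&t, b) // b = old_a - b (mod q)
}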
diff --git a/ecc/secp256k1/fp/vector.go b/ecc/secp256k1/fp/vector.go
index fa22cb416a..c97b4283ce 100644
--- a/ecc/secp256k1/fp/vector.go
+++ b/ecc/secp256k1/fp/vector.go
@@ -199,43 +199,6 @@ func (vector Vector) Swap(i, j int) {
     vector[i], vector[j] = vector[j], vector[i]
 }
 
-// Add adds two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Add(a, b Vector) {
-    addVecGeneric(*vector, a, b)
-}
-
-// Sub subtracts two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Sub(a, b Vector) {
-    subVecGeneric(*vector, a, b)
-}
-
-// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) ScalarMul(a Vector, b *Element) {
-    scalarMulVecGeneric(*vector, a, b)
-}
-
-// Sum computes the sum of all elements in the vector.
-func (vector *Vector) Sum() (res Element) {
-    sumVecGeneric(&res, *vector)
-    return
-}
-
-// InnerProduct computes the inner product of two vectors.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) InnerProduct(other Vector) (res Element) {
-    innerProductVecGeneric(&res, *vector, other)
-    return
-}
-
-// Mul multiplies two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Mul(a, b Vector) {
-    mulVecGeneric(*vector, a, b)
-}
-
 func addVecGeneric(res, a, b Vector) {
     if len(a) != len(b) || len(a) != len(res) {
         panic("vector.Add: vectors don't have the same length")
diff --git a/ecc/secp256k1/fp/vector_purego.go b/ecc/secp256k1/fp/vector_purego.go
new file mode 100644
index 0000000000..798b669887
--- /dev/null
+++ b/ecc/secp256k1/fp/vector_purego.go
@@ -0,0 +1,54 @@
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fp
+
+// Add adds two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Add(a, b Vector) {
+    addVecGeneric(*vector, a, b)
+}
+
+// Sub subtracts two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Sub(a, b Vector) {
+    subVecGeneric(*vector, a, b)
+}
+
+// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) ScalarMul(a Vector, b *Element) {
+    scalarMulVecGeneric(*vector, a, b)
+}
+
+// Sum computes the sum of all elements in the vector.
+func (vector *Vector) Sum() (res Element) {
+    sumVecGeneric(&res, *vector)
+    return
+}
+
+// InnerProduct computes the inner product of two vectors.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) InnerProduct(other Vector) (res Element) {
+    innerProductVecGeneric(&res, *vector, other)
+    return
+}
+
+// Mul multiplies two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Mul(a, b Vector) {
+    mulVecGeneric(*vector, a, b)
+}
diff --git a/ecc/secp256k1/fr/element_ops_purego.go b/ecc/secp256k1/fr/element_purego.go
similarity index 99%
rename from ecc/secp256k1/fr/element_ops_purego.go
rename to ecc/secp256k1/fr/element_purego.go
index eb7f9781e7..80b3116256 100644
--- a/ecc/secp256k1/fr/element_ops_purego.go
+++ b/ecc/secp256k1/fr/element_purego.go
@@ -41,15 +41,6 @@ func MulBy13(x *Element) {
     x.Mul(x, &y)
 }
 
-// TODO @gbotrel fixme.
-// Butterfly sets
-//
-// a = a + b (mod q)
-// b = a - b (mod q)
-func Butterfly(a, b *Element) {
-    _butterflyGeneric(a, b)
-}
-
 func fromMont(z *Element) {
     _fromMontGeneric(z)
 }
@@ -305,3 +296,11 @@ func (z *Element) Square(x *Element) *Element {
     }
     return z
 }
+
+// Butterfly sets
+//
+// a = a + b (mod q)
+// b = a - b (mod q)
+func Butterfly(a, b *Element) {
+    _butterflyGeneric(a, b)
+}
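Editor's note: these renamed files keep the generic fromMont wrapper. As background, a Montgomery representation can be stripped by multiplying by the raw integer 1 (not 1 in Montgomery form): Mul computes x·y·R⁻¹ mod q, so with y = 1 the result is x·R⁻¹, the canonical value. A sketch of that identity (an illustration of the math, not the code behind _fromMontGeneric):

// fromMontViaMul converts z out of Montgomery form using the identity
// mont(x) * 1 * R^-1 = x. Sketch only.
func fromMontViaMul(z *Element) {
    one := Element{1} // the raw integer 1, NOT 1 in Montgomery form
    z.Mul(z, &one)
}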
diff --git a/ecc/secp256k1/fr/vector.go b/ecc/secp256k1/fr/vector.go
index bcc71efcd8..867cabbc3d 100644
--- a/ecc/secp256k1/fr/vector.go
+++ b/ecc/secp256k1/fr/vector.go
@@ -199,43 +199,6 @@ func (vector Vector) Swap(i, j int) {
     vector[i], vector[j] = vector[j], vector[i]
 }
 
-// Add adds two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Add(a, b Vector) {
-    addVecGeneric(*vector, a, b)
-}
-
-// Sub subtracts two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Sub(a, b Vector) {
-    subVecGeneric(*vector, a, b)
-}
-
-// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) ScalarMul(a Vector, b *Element) {
-    scalarMulVecGeneric(*vector, a, b)
-}
-
-// Sum computes the sum of all elements in the vector.
-func (vector *Vector) Sum() (res Element) {
-    sumVecGeneric(&res, *vector)
-    return
-}
-
-// InnerProduct computes the inner product of two vectors.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) InnerProduct(other Vector) (res Element) {
-    innerProductVecGeneric(&res, *vector, other)
-    return
-}
-
-// Mul multiplies two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Mul(a, b Vector) {
-    mulVecGeneric(*vector, a, b)
-}
-
 func addVecGeneric(res, a, b Vector) {
     if len(a) != len(b) || len(a) != len(res) {
         panic("vector.Add: vectors don't have the same length")
diff --git a/ecc/secp256k1/fr/vector_purego.go b/ecc/secp256k1/fr/vector_purego.go
new file mode 100644
index 0000000000..04662dde33
--- /dev/null
+++ b/ecc/secp256k1/fr/vector_purego.go
@@ -0,0 +1,54 @@
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fr
+
+// Add adds two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Add(a, b Vector) {
+    addVecGeneric(*vector, a, b)
+}
+
+// Sub subtracts two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Sub(a, b Vector) {
+    subVecGeneric(*vector, a, b)
+}
+
+// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) ScalarMul(a Vector, b *Element) {
+    scalarMulVecGeneric(*vector, a, b)
+}
+
+// Sum computes the sum of all elements in the vector.
+func (vector *Vector) Sum() (res Element) {
+    sumVecGeneric(&res, *vector)
+    return
+}
+
+// InnerProduct computes the inner product of two vectors.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) InnerProduct(other Vector) (res Element) {
+    innerProductVecGeneric(&res, *vector, other)
+    return
+}
+
+// Mul multiplies two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Mul(a, b Vector) {
+    mulVecGeneric(*vector, a, b)
+}
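Editor's note: InnerProduct above accumulates into a named result via innerProductVecGeneric. A sketch of what that helper presumably does, following the same length-check-then-loop pattern (assumed shape, not the library source):

func innerProductVecGeneric(res *Element, a, b Vector) {
    if len(a) != len(b) {
        panic("vector.InnerProduct: vectors don't have the same length")
    }
    var tmp Element
    for i := 0; i < len(a); i++ {
        tmp.Mul(&a[i], &b[i]) // a[i]*b[i] (mod q)
        res.Add(res, &tmp)    // accumulate the running sum
    }
}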
diff --git a/ecc/stark-curve/fp/arith.go b/ecc/stark-curve/fp/arith.go
index 6f281563b3..66fa667482 100644
--- a/ecc/stark-curve/fp/arith.go
+++ b/ecc/stark-curve/fp/arith.go
@@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) {
     hi, _ = bits.Add64(hi, e, carry)
     return
 }
-func max(a int, b int) int {
-    if a > b {
-        return a
-    }
-    return b
-}
-
-func min(a int, b int) int {
-    if a < b {
-        return a
-    }
-    return b
-}
diff --git a/ecc/stark-curve/fp/asm_adx.go b/ecc/stark-curve/fp/asm_adx.go
index 0481989ec6..f8e29bd1a7 100644
--- a/ecc/stark-curve/fp/asm_adx.go
+++ b/ecc/stark-curve/fp/asm_adx.go
@@ -1,5 +1,4 @@
 //go:build !noadx
-// +build !noadx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/stark-curve/fp/asm_avx.go b/ecc/stark-curve/fp/asm_avx.go
index cea035ee84..52fc07a325 100644
--- a/ecc/stark-curve/fp/asm_avx.go
+++ b/ecc/stark-curve/fp/asm_avx.go
@@ -1,5 +1,4 @@
 //go:build !noavx
-// +build !noavx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/stark-curve/fp/asm_noadx.go b/ecc/stark-curve/fp/asm_noadx.go
index 92f8cc0f42..cb6cfa0f50 100644
--- a/ecc/stark-curve/fp/asm_noadx.go
+++ b/ecc/stark-curve/fp/asm_noadx.go
@@ -1,5 +1,4 @@
 //go:build noadx
-// +build noadx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/stark-curve/fp/asm_noavx.go b/ecc/stark-curve/fp/asm_noavx.go
index 9ca08a375a..12261b1f03 100644
--- a/ecc/stark-curve/fp/asm_noavx.go
+++ b/ecc/stark-curve/fp/asm_noavx.go
@@ -1,5 +1,4 @@
 //go:build noavx
-// +build noavx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/stark-curve/fp/element_amd64.go b/ecc/stark-curve/fp/element_amd64.go
new file mode 100644
index 0000000000..77a51ee25e
--- /dev/null
+++ b/ecc/stark-curve/fp/element_amd64.go
@@ -0,0 +1,66 @@
+//go:build !purego
+
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fp
+
+//go:noescape
+func MulBy3(x *Element)
+
+//go:noescape
+func MulBy5(x *Element)
+
+//go:noescape
+func MulBy13(x *Element)
+
+//go:noescape
+func mul(res, x, y *Element)
+
+//go:noescape
+func fromMont(res *Element)
+
+//go:noescape
+func reduce(res *Element)
+
+// Butterfly sets
+//
+// a = a + b (mod q)
+// b = a - b (mod q)
+//
+//go:noescape
+func Butterfly(a, b *Element)
+
+// Mul z = x * y (mod q)
+//
+// x and y must be less than q
+func (z *Element) Mul(x, y *Element) *Element {
+
+    // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS"
+    // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521
+
+    mul(z, x, y)
+    return z
+}
+
+// Square z = x * x (mod q)
+//
+// x must be less than q
+func (z *Element) Square(x *Element) *Element {
+    // see Mul for doc.
+    mul(z, x, x)
+    return z
+}
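Editor's note: element_amd64.go is the Go half of a Go/assembly pair — bodiless //go:noescape declarations implemented in the .s file included just below. A toy pair showing the mechanism (hypothetical names, not from this patch; amd64 Plan 9 assembly):

// stub.go
//go:build !purego

package fp

//go:noescape
func addOneLimb(res, x *uint64) // body lives in stub_amd64.s

// stub_amd64.s
#include "textflag.h"

// func addOneLimb(res, x *uint64)
TEXT ·addOneLimb(SB), NOSPLIT, $0-16
    MOVQ res+0(FP), AX // load the res pointer
    MOVQ x+8(FP), BX   // load the x pointer
    MOVQ (BX), CX      // *x
    ADDQ CX, (AX)      // *res += *x (no modular reduction; toy only)
    RET

The //go:noescape pragma tells the compiler the assembly never leaks the pointer arguments to the heap, which keeps callers allocation-free.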
diff --git a/ecc/stark-curve/fp/element_amd64.s b/ecc/stark-curve/fp/element_amd64.s
new file mode 100644
index 0000000000..fb00194d7e
--- /dev/null
+++ b/ecc/stark-curve/fp/element_amd64.s
@@ -0,0 +1,21 @@
+//go:build !purego
+
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+// We include the hash to force the Go compiler to recompile: 9425145785761608449
+#include "../../../field/asm/element_4w_amd64.s"
+
diff --git a/ecc/stark-curve/fp/element_arm64.go b/ecc/stark-curve/fp/element_arm64.go
new file mode 100644
index 0000000000..c77bc2ea84
--- /dev/null
+++ b/ecc/stark-curve/fp/element_arm64.go
@@ -0,0 +1,78 @@
+//go:build !purego
+
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fp
+
+// Butterfly sets
+//
+// a = a + b (mod q)
+// b = a - b (mod q)
+//
+//go:noescape
+func Butterfly(a, b *Element)
+
+//go:noescape
+func mul(res, x, y *Element)
+
+// Mul z = x * y (mod q)
+//
+// x and y must be less than q
+func (z *Element) Mul(x, y *Element) *Element {
+    mul(z, x, y)
+    return z
+}
+
+// Square z = x * x (mod q)
+//
+// x must be less than q
+func (z *Element) Square(x *Element) *Element {
+    // see Mul for doc.
+    mul(z, x, x)
+    return z
+}
+
+// MulBy3 x *= 3 (mod q)
+func MulBy3(x *Element) {
+    _x := *x
+    x.Double(x).Add(x, &_x)
+}
+
+// MulBy5 x *= 5 (mod q)
+func MulBy5(x *Element) {
+    _x := *x
+    x.Double(x).Double(x).Add(x, &_x)
+}
+
+// MulBy13 x *= 13 (mod q)
+func MulBy13(x *Element) {
+    var y = Element{
+        18446744073709551201,
+        18446744073709551615,
+        18446744073709551615,
+        576460752303416432,
+    }
+    x.Mul(x, &y)
+}
+
+func fromMont(z *Element) {
+    _fromMontGeneric(z)
+}
+
+func reduce(z *Element) {
+    _reduceGeneric(z)
+}
diff --git a/ecc/stark-curve/fp/element_arm64.s b/ecc/stark-curve/fp/element_arm64.s
new file mode 100644
index 0000000000..3cd8aaa667
--- /dev/null
+++ b/ecc/stark-curve/fp/element_arm64.s
@@ -0,0 +1,21 @@
+//go:build !purego
+
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+// We include the hash to force the Go compiler to recompile: 18027907654287790676
+#include "../../../field/asm/element_4w_arm64.s"
+
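Editor's note: the MulBy13 literal above is 13 in Montgomery form, i.e. 13·2^256 mod q, stored as little-endian 64-bit limbs. A sketch of how such a constant can be reproduced with math/big, constructing the well-known stark-curve fp modulus q = 2^251 + 17·2^192 + 1 inline (illustration only; comparing the printed value against the Element literal's limbs is left to the reader):

package main

import (
    "fmt"
    "math/big"
)

func main() {
    // q = 2^251 + 17*2^192 + 1 (the stark-curve fp modulus)
    q := new(big.Int).Lsh(big.NewInt(1), 251)
    q.Add(q, new(big.Int).Lsh(big.NewInt(17), 192))
    q.Add(q, big.NewInt(1))

    // 13 in Montgomery form: 13 * 2^(64*4) mod q
    c := new(big.Int).Lsh(big.NewInt(13), 256)
    c.Mod(c, q)

    fmt.Printf("%x\n", c) // compare against the MulBy13 Element literal
}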
-// We include the hash to force the Go compiler to recompile: 9425145785761608449
-#include "../../../field/asm/element_4w_amd64.s"
-
diff --git a/ecc/stark-curve/fp/element_ops_arm64.go b/ecc/stark-curve/fp/element_ops_arm64.go
deleted file mode 100644
index 78ae87b96b..0000000000
--- a/ecc/stark-curve/fp/element_ops_arm64.go
+++ /dev/null
@@ -1,31 +0,0 @@
-//go:build !purego
-// +build !purego
-
-// Copyright 2020 ConsenSys Software Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Code generated by consensys/gnark-crypto DO NOT EDIT
-
-package fp
-
-//go:noescape
-func Butterfly(a, b *Element)
-
-//go:noescape
-func mul(res, x, y *Element)
-
-func (z *Element) Mul(x, y *Element) *Element {
-    mul(z, x, y)
-    return z
-}
diff --git a/ecc/stark-curve/fp/element_ops_arm64.s b/ecc/stark-curve/fp/element_ops_arm64.s
deleted file mode 100644
index 6ba54c61aa..0000000000
--- a/ecc/stark-curve/fp/element_ops_arm64.s
+++ /dev/null
@@ -1,6 +0,0 @@
-// +build !purego
-
-// Code generated by gnark-crypto/generator. DO NOT EDIT.
-// We include the hash to force the Go compiler to recompile: 18027907654287790676
-#include "../../../field/asm/element_4w_arm64.s"
-
diff --git a/ecc/stark-curve/fp/element_ops_purego.go b/ecc/stark-curve/fp/element_ops_purego.go
deleted file mode 100644
index 189eb1054a..0000000000
--- a/ecc/stark-curve/fp/element_ops_purego.go
+++ /dev/null
@@ -1,260 +0,0 @@
-//go:build !amd64 || purego
-// +build !amd64 purego
-
-// Copyright 2020 ConsenSys Software Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Code generated by consensys/gnark-crypto DO NOT EDIT
-
-package fp
-
-import "math/bits"
-
-// MulBy3 x *= 3 (mod q)
-func MulBy3(x *Element) {
-    _x := *x
-    x.Double(x).Add(x, &_x)
-}
-
-// MulBy5 x *= 5 (mod q)
-func MulBy5(x *Element) {
-    _x := *x
-    x.Double(x).Double(x).Add(x, &_x)
-}
-
-// MulBy13 x *= 13 (mod q)
-func MulBy13(x *Element) {
-    var y = Element{
-        18446744073709551201,
-        18446744073709551615,
-        18446744073709551615,
-        576460752303416432,
-    }
-    x.Mul(x, &y)
-}
-
-func fromMont(z *Element) {
-    _fromMontGeneric(z)
-}
-
-func reduce(z *Element) {
-    _reduceGeneric(z)
-}
-
-// Add adds two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Add(a, b Vector) {
-    addVecGeneric(*vector, a, b)
-}
-
-// Sub subtracts two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Sub(a, b Vector) {
-    subVecGeneric(*vector, a, b)
-}
-
-// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) ScalarMul(a Vector, b *Element) {
-    scalarMulVecGeneric(*vector, a, b)
-}
-
-// Sum computes the sum of all elements in the vector.
-func (vector *Vector) Sum() (res Element) {
-    sumVecGeneric(&res, *vector)
-    return
-}
-
-// InnerProduct computes the inner product of two vectors.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) InnerProduct(other Vector) (res Element) {
-    innerProductVecGeneric(&res, *vector, other)
-    return
-}
-
-// Mul multiplies two vectors element-wise and stores the result in self.
-// It panics if the vectors don't have the same length.
-func (vector *Vector) Mul(a, b Vector) {
-    mulVecGeneric(*vector, a, b)
-}
-
-// Square z = x * x (mod q)
-//
-// x must be less than q
-func (z *Element) Square(x *Element) *Element {
-    // see Mul for algorithm documentation
-
-    var t0, t1, t2, t3 uint64
-    var u0, u1, u2, u3 uint64
-    {
-        var c0, c1, c2 uint64
-        v := x[0]
-        u0, t0 = bits.Mul64(v, x[0])
-        u1, t1 = bits.Mul64(v, x[1])
-        u2, t2 = bits.Mul64(v, x[2])
-        u3, t3 = bits.Mul64(v, x[3])
-        t1, c0 = bits.Add64(u0, t1, 0)
-        t2, c0 = bits.Add64(u1, t2, c0)
-        t3, c0 = bits.Add64(u2, t3, c0)
-        c2, _ = bits.Add64(u3, 0, c0)
-
-        m := qInvNeg * t0
-
-        u0, c1 = bits.Mul64(m, q0)
-        _, c0 = bits.Add64(t0, c1, 0)
-        u1, c1 = bits.Mul64(m, q1)
-        t0, c0 = bits.Add64(t1, c1, c0)
-        u2, c1 = bits.Mul64(m, q2)
-        t1, c0 = bits.Add64(t2, c1, c0)
-        u3, c1 = bits.Mul64(m, q3)
-
-        t2, c0 = bits.Add64(0, c1, c0)
-        u3, _ = bits.Add64(u3, 0, c0)
-        t0, c0 = bits.Add64(u0, t0, 0)
-        t1, c0 = bits.Add64(u1, t1, c0)
-        t2, c0 = bits.Add64(u2, t2, c0)
-        c2, _ = bits.Add64(c2, 0, c0)
-        t2, c0 = bits.Add64(t3, t2, 0)
-        t3, _ = bits.Add64(u3, c2, c0)
-
-    }
-    {
-        var c0, c1, c2 uint64
-        v := x[1]
-        u0, c1 = bits.Mul64(v, x[0])
-        t0, c0 = bits.Add64(c1, t0, 0)
-        u1, c1 = bits.Mul64(v, x[1])
-        t1, c0 = bits.Add64(c1, t1, c0)
-        u2, c1 = bits.Mul64(v, x[2])
-        t2, c0 = bits.Add64(c1, t2, c0)
-        u3, c1 = bits.Mul64(v, x[3])
-        t3, c0 = bits.Add64(c1, t3, c0)
-
-        c2, _ = bits.Add64(0, 0, c0)
-        t1, c0 = bits.Add64(u0, t1, 0)
-        t2, c0 = bits.Add64(u1, t2, c0)
-        t3, c0 = bits.Add64(u2, t3, c0)
-        c2, _ = bits.Add64(u3, c2, c0)
-
-        m := qInvNeg * t0
-
-        u0, c1 = bits.Mul64(m, q0)
-        _, c0 = bits.Add64(t0, c1, 0)
-        u1, c1 = bits.Mul64(m, q1)
-        t0, c0 = bits.Add64(t1, c1, c0)
-        u2, c1 = bits.Mul64(m, q2)
-        t1, c0 = bits.Add64(t2, c1, c0)
-        u3, c1 = bits.Mul64(m, q3)
-
-        t2, c0 = bits.Add64(0, c1, c0)
-        u3, _ = bits.Add64(u3, 0, c0)
-        t0, c0 = bits.Add64(u0, t0, 0)
-        t1, c0 = bits.Add64(u1, t1, c0)
-        t2, c0 = bits.Add64(u2, t2, c0)
-        c2, _ = bits.Add64(c2, 0, c0)
-        t2, c0 = bits.Add64(t3, t2, 0)
-        t3, _ = bits.Add64(u3, c2, c0)
-
-    }
-    {
-        var c0, c1, c2 uint64
-        v := x[2]
-        u0, c1 = bits.Mul64(v, x[0])
-        t0, c0 = bits.Add64(c1, t0, 0)
-        u1, c1 = bits.Mul64(v, x[1])
-        t1, c0 = bits.Add64(c1, t1, c0)
-        u2, c1 = bits.Mul64(v, x[2])
-        t2, c0 = bits.Add64(c1, t2, c0)
-        u3, c1 = bits.Mul64(v, x[3])
-        t3, c0 = bits.Add64(c1, t3, c0)
-
-        c2, _ = bits.Add64(0, 0, c0)
-        t1, c0 = bits.Add64(u0, t1, 0)
-        t2, c0 = bits.Add64(u1, t2, c0)
-        t3, c0 = bits.Add64(u2, t3, c0)
-        c2, _ = bits.Add64(u3, c2, c0)
-
-        m := qInvNeg * t0
-
-        u0, c1 = bits.Mul64(m, q0)
-        _, c0 = bits.Add64(t0, c1, 0)
-        u1, c1 = bits.Mul64(m, q1)
-        t0, c0 = bits.Add64(t1, c1, c0)
-        u2, c1 = bits.Mul64(m, q2)
-        t1, c0 = bits.Add64(t2, c1, c0)
-        u3, c1 = bits.Mul64(m, q3)
-
-        t2, c0 = bits.Add64(0, c1, c0)
-        u3, _ = bits.Add64(u3, 0, c0)
-        t0, c0 = bits.Add64(u0, t0, 0)
-        t1, c0 = bits.Add64(u1, t1, c0)
-        t2, c0 = bits.Add64(u2, t2, c0)
-        c2, _ = bits.Add64(c2, 0, c0)
-        t2, c0 = bits.Add64(t3, t2, 0)
-        t3, _ = bits.Add64(u3, c2, c0)
-
-    }
-    {
-        var c0, c1, c2 uint64
-        v := x[3]
-        u0, c1 = bits.Mul64(v, x[0])
-        t0, c0 = bits.Add64(c1, t0, 0)
-        u1, c1 = bits.Mul64(v, x[1])
-        t1, c0 = bits.Add64(c1, t1, c0)
-        u2, c1 = bits.Mul64(v, x[2])
-        t2, c0 = bits.Add64(c1, t2, c0)
-        u3, c1 = bits.Mul64(v, x[3])
-        t3, c0 = bits.Add64(c1, t3, c0)
-
-        c2, _ = bits.Add64(0, 0, c0)
-        t1, c0 = bits.Add64(u0, t1, 0)
-        t2, c0 = bits.Add64(u1, t2, c0)
-        t3, c0 = bits.Add64(u2, t3, c0)
-        c2, _ = bits.Add64(u3, c2, c0)
-
-        m := qInvNeg * t0
-
-        u0, c1 = bits.Mul64(m, q0)
-        _, c0 = bits.Add64(t0, c1, 0)
-        u1, c1 = bits.Mul64(m, q1)
-        t0, c0 = bits.Add64(t1, c1, c0)
-        u2, c1 = bits.Mul64(m, q2)
-        t1, c0 = bits.Add64(t2, c1, c0)
-        u3, c1 = bits.Mul64(m, q3)
-
-        t2, c0 = bits.Add64(0, c1, c0)
-        u3, _ = bits.Add64(u3, 0, c0)
-        t0, c0 = bits.Add64(u0, t0, 0)
-        t1, c0 = bits.Add64(u1, t1, c0)
-        t2, c0 = bits.Add64(u2, t2, c0)
-        c2, _ = bits.Add64(c2, 0, c0)
-        t2, c0 = bits.Add64(t3, t2, 0)
-        t3, _ = bits.Add64(u3, c2, c0)
-
-    }
-    z[0] = t0
-    z[1] = t1
-    z[2] = t2
-    z[3] = t3
-
-    // if z ⩾ q → z -= q
-    if !z.smallerThanModulus() {
-        var b uint64
-        z[0], b = bits.Sub64(z[0], q0, 0)
-        z[1], b = bits.Sub64(z[1], q1, b)
-        z[2], b = bits.Sub64(z[2], q2, b)
-        z[3], _ = bits.Sub64(z[3], q3, b)
-    }
-    return z
-}
diff --git a/ecc/bls12-381/fp/element_ops_purego.go b/ecc/stark-curve/fp/element_purego.go
similarity index 59%
rename from ecc/bls12-381/fp/element_ops_purego.go
rename to ecc/stark-curve/fp/element_purego.go
index e818762b17..380a94c00c 100644
--- a/ecc/bls12-381/fp/element_ops_purego.go
+++ b/ecc/stark-curve/fp/element_purego.go
@@ -1,5 +1,4 @@
-//go:build !amd64 || purego
-// +build !amd64 purego
+//go:build purego
 
 // Copyright 2020 ConsenSys Software Inc.
 //
@@ -36,12 +35,10 @@ func MulBy5(x *Element) {
 // MulBy13 x *= 13 (mod q)
 func MulBy13(x *Element) {
     var y = Element{
-        13438459813099623723,
-        14459933216667336738,
-        14900020990258308116,
-        2941282712809091851,
-        13639094935183769893,
-        1835248516986607988,
+        18446744073709551201,
+        18446744073709551615,
+        18446744073709551615,
+        576460752303416432,
     }
     x.Mul(x, &y)
 }
@@ -54,29 +51,27 @@ func reduce(z *Element) {
     _reduceGeneric(z)
 }
 
-// Square z = x * x (mod q)
+// Mul z = x * y (mod q)
 //
-// x must be less than q
-func (z *Element) Square(x *Element) *Element {
-    // see Mul for algorithm documentation
+// x and y must be less than q
+func (z *Element) Mul(x, y *Element) *Element {
 
-    var t0, t1, t2, t3, t4, t5 uint64
-    var u0, u1, u2, u3, u4, u5 uint64
+    // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS"
+    // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521
+
+    var t0, t1, t2, t3 uint64
+    var u0, u1, u2, u3 uint64
     {
         var c0, c1, c2 uint64
         v := x[0]
-        u0, t0 = bits.Mul64(v, x[0])
-        u1, t1 = bits.Mul64(v, x[1])
-        u2, t2 = bits.Mul64(v, x[2])
-        u3, t3 = bits.Mul64(v, x[3])
-        u4, t4 = bits.Mul64(v, x[4])
-        u5, t5 = bits.Mul64(v, x[5])
+        u0, t0 = bits.Mul64(v, y[0])
+        u1, t1 = bits.Mul64(v, y[1])
+        u2, t2 = bits.Mul64(v, y[2])
+        u3, t3 = bits.Mul64(v, y[3])
         t1, c0 = bits.Add64(u0, t1, 0)
         t2, c0 = bits.Add64(u1, t2, c0)
         t3, c0 = bits.Add64(u2, t3, c0)
-        t4, c0 = bits.Add64(u3, t4, c0)
-        t5, c0 = bits.Add64(u4, t5, c0)
-        c2, _ = bits.Add64(u5, 0, c0)
+        c2, _ = bits.Add64(u3, 0, c0)
 
         m := qInvNeg * t0
 
@@ -87,46 +82,34 @@ func (z *Element) Square(x *Element) *Element {
         u2, c1 = bits.Mul64(m, q2)
         t1, c0 = bits.Add64(t2, c1, c0)
         u3, c1 = bits.Mul64(m, q3)
-        t2, c0 = bits.Add64(t3, c1, c0)
-        u4, c1 = bits.Mul64(m, q4)
-        t3, c0 = bits.Add64(t4, c1, c0)
-        u5, c1 = bits.Mul64(m, q5)
 
-        t4, c0 = bits.Add64(0, c1, c0)
-        u5, _ = bits.Add64(u5, 0, c0)
+        t2, c0 = bits.Add64(0, c1, c0)
+        u3, _ = bits.Add64(u3, 0, c0)
         t0, c0 = bits.Add64(u0, t0, 0)
         t1, c0 = bits.Add64(u1, t1, c0)
         t2, c0 = bits.Add64(u2, t2, c0)
-        t3, c0 = bits.Add64(u3, t3, c0)
-        t4, c0 = bits.Add64(u4, t4, c0)
         c2, _ = bits.Add64(c2, 0, c0)
-        t4, c0 = bits.Add64(t5, t4, 0)
-        t5, _ = bits.Add64(u5, c2, c0)
+        t2, c0 = bits.Add64(t3, t2, 0)
+        t3, _ = bits.Add64(u3, c2, c0)
 
     }
     {
         var c0, c1, c2 uint64
         v := x[1]
-        u0, c1 = bits.Mul64(v, x[0])
+        u0, c1 = bits.Mul64(v, y[0])
         t0, c0 = bits.Add64(c1, t0, 0)
-        u1, c1 = bits.Mul64(v, x[1])
+        u1, c1 = bits.Mul64(v, y[1])
        t1, c0 = bits.Add64(c1, t1, c0)
-        u2, c1 = bits.Mul64(v, x[2])
+        u2, c1 = bits.Mul64(v, y[2])
         t2, c0 = bits.Add64(c1, t2, c0)
-        u3, c1 = bits.Mul64(v, x[3])
+        u3, c1 = bits.Mul64(v, y[3])
         t3, c0 = bits.Add64(c1, t3, c0)
-        u4, c1 = bits.Mul64(v, x[4])
-        t4, c0 = bits.Add64(c1, t4, c0)
-        u5, c1 = bits.Mul64(v, x[5])
-        t5, c0 = bits.Add64(c1, t5, c0)
 
         c2, _ = bits.Add64(0, 0, c0)
         t1, c0 = bits.Add64(u0, t1, 0)
         t2, c0 = bits.Add64(u1, t2, c0)
         t3, c0 = bits.Add64(u2, t3, c0)
-        t4, c0 = bits.Add64(u3, t4, c0)
-        t5, c0 = bits.Add64(u4, t5, c0)
-        c2, _ = bits.Add64(u5, c2, c0)
+        c2, _ = bits.Add64(u3, c2, c0)
 
         m := qInvNeg * t0
 
@@ -137,46 +120,34 @@ func (z *Element) Square(x *Element) *Element {
         u2, c1 = bits.Mul64(m, q2)
         t1, c0 = bits.Add64(t2, c1, c0)
         u3, c1 = bits.Mul64(m, q3)
-        t2, c0 = bits.Add64(t3, c1, c0)
-        u4, c1 = bits.Mul64(m, q4)
-        t3, c0 = bits.Add64(t4, c1, c0)
-        u5, c1 = bits.Mul64(m, q5)
 
-        t4, c0 = bits.Add64(0, c1, c0)
-        u5, _ = bits.Add64(u5, 0, c0)
+        t2, c0 = bits.Add64(0, c1, c0)
+        u3, _ = bits.Add64(u3, 0, c0)
         t0, c0 = bits.Add64(u0, t0, 0)
         t1, c0 = bits.Add64(u1, t1, c0)
         t2, c0 = bits.Add64(u2, t2, c0)
-        t3, c0 = bits.Add64(u3, t3, c0)
-        t4, c0 = bits.Add64(u4, t4, c0)
        c2, _ = bits.Add64(c2, 0, c0)
-        t4, c0 = bits.Add64(t5, t4, 0)
-        t5, _ = bits.Add64(u5, c2, c0)
+        t2, c0 = bits.Add64(t3, t2, 0)
+        t3, _ = bits.Add64(u3, c2, c0)
 
     }
     {
         var c0, c1, c2 uint64
         v := x[2]
-        u0, c1 = bits.Mul64(v, x[0])
+        u0, c1 = bits.Mul64(v, y[0])
         t0, c0 = bits.Add64(c1, t0, 0)
-        u1, c1 = bits.Mul64(v, x[1])
+        u1, c1 = bits.Mul64(v, y[1])
         t1, c0 = bits.Add64(c1, t1, c0)
-        u2, c1 = bits.Mul64(v, x[2])
+        u2, c1 = bits.Mul64(v, y[2])
         t2, c0 = bits.Add64(c1, t2, c0)
-        u3, c1 = bits.Mul64(v, x[3])
+        u3, c1 = bits.Mul64(v, y[3])
         t3, c0 = bits.Add64(c1, t3, c0)
-        u4, c1 = bits.Mul64(v, x[4])
-        t4, c0 = bits.Add64(c1, t4, c0)
-        u5, c1 = bits.Mul64(v, x[5])
-        t5, c0 = bits.Add64(c1, t5, c0)
 
         c2, _ = bits.Add64(0, 0, c0)
         t1, c0 = bits.Add64(u0, t1, 0)
         t2, c0 = bits.Add64(u1, t2, c0)
         t3, c0 = bits.Add64(u2, t3, c0)
-        t4, c0 = bits.Add64(u3, t4, c0)
-        t5, c0 = bits.Add64(u4, t5, c0)
-        c2, _ = bits.Add64(u5, c2, c0)
+        c2, _ = bits.Add64(u3, c2, c0)
 
         m := qInvNeg * t0
 
@@ -187,26 +158,114 @@ func (z *Element) Square(x *Element) *Element {
         u2, c1 = bits.Mul64(m, q2)
         t1, c0 = bits.Add64(t2, c1, c0)
         u3, c1 = bits.Mul64(m, q3)
-        t2, c0 = bits.Add64(t3, c1, c0)
-        u4, c1 = bits.Mul64(m, q4)
-        t3, c0 = bits.Add64(t4, c1, c0)
-        u5, c1 = bits.Mul64(m, q5)
 
-        t4, c0 = bits.Add64(0, c1, c0)
-        u5, _ = bits.Add64(u5, 0, c0)
+        t2, c0 = bits.Add64(0, c1, c0)
+        u3, _ = bits.Add64(u3, 0, c0)
         t0, c0 = bits.Add64(u0, t0, 0)
         t1, c0 = bits.Add64(u1, t1, c0)
         t2, c0 = bits.Add64(u2, t2, c0)
-        t3, c0 = bits.Add64(u3, t3, c0)
-        t4, c0 = bits.Add64(u4, t4, c0)
         c2, _ = bits.Add64(c2, 0, c0)
-        t4, c0 = bits.Add64(t5, t4, 0)
-        t5, _ = bits.Add64(u5, c2, c0)
+        t2, c0 = bits.Add64(t3, t2, 0)
+        t3, _ = bits.Add64(u3, c2, c0)
 
     }
     {
         var c0, c1, c2 uint64
         v := x[3]
+        u0, c1 = bits.Mul64(v, y[0])
+        t0, c0 = bits.Add64(c1, t0, 0)
+        u1, c1 = bits.Mul64(v, y[1])
+        t1, c0 = bits.Add64(c1, t1, c0)
+        u2, c1 = bits.Mul64(v, y[2])
+        t2, c0 = bits.Add64(c1, t2, c0)
+        u3, c1 = bits.Mul64(v, y[3])
+        t3, c0 = bits.Add64(c1, t3, c0)
+
+        c2, _ = bits.Add64(0, 0, c0)
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        c2, _ = bits.Add64(u3, c2, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+
+        t2, c0 = bits.Add64(0, c1, c0)
+        u3, _ = bits.Add64(u3, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t2, c0 = bits.Add64(t3, t2, 0)
+        t3, _ = bits.Add64(u3, c2, c0)
+
+    }
+    z[0] = t0
+    z[1] = t1
+    z[2] = t2
+    z[3] = t3
+
+    // if z ⩾ q → z -= q
+    if !z.smallerThanModulus() {
+        var b uint64
+        z[0], b = bits.Sub64(z[0], q0, 0)
+        z[1], b = bits.Sub64(z[1], q1, b)
+        z[2], b = bits.Sub64(z[2], q2, b)
+        z[3], _ = bits.Sub64(z[3], q3, b)
+    }
+    return z
+}
+
+// Square z = x * x (mod q)
+//
+// x must be less than q
+func (z *Element) Square(x *Element) *Element {
+    // see Mul for algorithm documentation
+
+    var t0, t1, t2, t3 uint64
+    var u0, u1, u2, u3 uint64
+    {
+        var c0, c1, c2 uint64
+        v := x[0]
+        u0, t0 = bits.Mul64(v, x[0])
+        u1, t1 = bits.Mul64(v, x[1])
+        u2, t2 = bits.Mul64(v, x[2])
+        u3, t3 = bits.Mul64(v, x[3])
+        t1, c0 = bits.Add64(u0, t1, 0)
+        t2, c0 = bits.Add64(u1, t2, c0)
+        t3, c0 = bits.Add64(u2, t3, c0)
+        c2, _ = bits.Add64(u3, 0, c0)
+
+        m := qInvNeg * t0
+
+        u0, c1 = bits.Mul64(m, q0)
+        _, c0 = bits.Add64(t0, c1, 0)
+        u1, c1 = bits.Mul64(m, q1)
+        t0, c0 = bits.Add64(t1, c1, c0)
+        u2, c1 = bits.Mul64(m, q2)
+        t1, c0 = bits.Add64(t2, c1, c0)
+        u3, c1 = bits.Mul64(m, q3)
+
+        t2, c0 = bits.Add64(0, c1, c0)
+        u3, _ = bits.Add64(u3, 0, c0)
+        t0, c0 = bits.Add64(u0, t0, 0)
+        t1, c0 = bits.Add64(u1, t1, c0)
+        t2, c0 = bits.Add64(u2, t2, c0)
+        c2, _ = bits.Add64(c2, 0, c0)
+        t2, c0 = bits.Add64(t3, t2, 0)
+        t3, _ = bits.Add64(u3, c2, c0)
+
+    }
+    {
+        var c0, c1, c2 uint64
+        v := x[1]
 u0, c1 = bits.Mul64(v, x[0])
 t0, c0 = bits.Add64(c1, t0, 0)
 u1, c1 = bits.Mul64(v, x[1])
@@ -215,18 +274,12 @@ func (z *Element) Square(x *Element) *Element {
     t2, c0 = bits.Add64(c1, t2, c0)
     u3, c1 = bits.Mul64(v, x[3])
     t3, c0 = bits.Add64(c1, t3, c0)
-    u4, c1 = bits.Mul64(v, x[4])
-    t4, c0 = bits.Add64(c1, t4, c0)
-    u5, c1 = bits.Mul64(v, x[5])
-    t5, c0 = bits.Add64(c1, t5, c0)
 
     c2, _ = bits.Add64(0, 0, c0)
     t1, c0 = bits.Add64(u0, t1, 0)
     t2, c0 = bits.Add64(u1, t2, c0)
     t3, c0 = bits.Add64(u2, t3, c0)
-    t4, c0 = bits.Add64(u3, t4, c0)
-    t5, c0 = bits.Add64(u4, t5, c0)
-    c2, _ = bits.Add64(u5, c2, c0)
+    c2, _ = bits.Add64(u3, c2, c0)
 
     m := qInvNeg * t0
 
@@ -237,26 +290,20 @@ func (z *Element) Square(x *Element) *Element {
     u2, c1 = bits.Mul64(m, q2)
     t1, c0 = bits.Add64(t2, c1, c0)
     u3, c1 = bits.Mul64(m, q3)
-    t2, c0 = bits.Add64(t3, c1, c0)
-    u4, c1 = bits.Mul64(m, q4)
-    t3, c0 = bits.Add64(t4, c1, c0)
-    u5, c1 = bits.Mul64(m, q5)
 
-    t4, c0 = bits.Add64(0, c1, c0)
-    u5, _ = bits.Add64(u5, 0, c0)
+    t2, c0 = bits.Add64(0, c1, c0)
+    u3, _ = bits.Add64(u3, 0, c0)
     t0, c0 = bits.Add64(u0, t0, 0)
     t1, c0 = bits.Add64(u1, t1, c0)
     t2, c0 = bits.Add64(u2, t2, c0)
-    t3, c0 = bits.Add64(u3, t3, c0)
-    t4, c0 = bits.Add64(u4, t4, c0)
     c2, _ = bits.Add64(c2, 0, c0)
-    t4, c0 = bits.Add64(t5, t4, 0)
-    t5, _ = bits.Add64(u5, c2, c0)
+    t2, c0 = bits.Add64(t3, t2, 0)
+    t3, _ = bits.Add64(u3, c2, c0)
 
 }
 {
     var c0, c1, c2 uint64
-    v := x[4]
+    v := x[2]
     u0, c1 = bits.Mul64(v, x[0])
     t0, c0 = bits.Add64(c1, t0, 0)
     u1, c1 = bits.Mul64(v, x[1])
@@ -265,18 +312,12 @@ func (z *Element) Square(x *Element) *Element {
     t2, c0 = bits.Add64(c1, t2, c0)
     u3, c1 = bits.Mul64(v, x[3])
     t3, c0 = bits.Add64(c1, t3, c0)
-    u4, c1 = bits.Mul64(v, x[4])
-    t4, c0 = bits.Add64(c1, t4, c0)
-    u5, c1 = bits.Mul64(v, x[5])
-    t5, c0 = bits.Add64(c1, t5, c0)
 
     c2, _ = bits.Add64(0, 0, c0)
     t1, c0 = bits.Add64(u0, t1, 0)
     t2, c0 = bits.Add64(u1, t2, c0)
     t3, c0 = bits.Add64(u2, t3, c0)
-    t4, c0 = bits.Add64(u3, t4, c0)
-    t5, c0 = bits.Add64(u4, t5, c0)
-    c2, _ = bits.Add64(u5, c2, c0)
+    c2, _ = bits.Add64(u3, c2, c0)
 
     m := qInvNeg * t0
 
@@ -287,26 +328,20 @@ func (z *Element) Square(x *Element) *Element {
     u2, c1 = bits.Mul64(m, q2)
     t1, c0 = bits.Add64(t2, c1, c0)
     u3, c1 = bits.Mul64(m, q3)
-    t2, c0 = bits.Add64(t3, c1, c0)
-    u4, c1 = bits.Mul64(m, q4)
-    t3, c0 = bits.Add64(t4, c1, c0)
-    u5, c1 = bits.Mul64(m, q5)
 
-    t4, c0 = bits.Add64(0, c1, c0)
-    u5, _ = bits.Add64(u5, 0, c0)
+    t2, c0 = bits.Add64(0, c1, c0)
+    u3, _ = bits.Add64(u3, 0, c0)
     t0, c0 = bits.Add64(u0, t0, 0)
     t1, c0 = bits.Add64(u1, t1, c0)
    t2, c0 = bits.Add64(u2, t2, c0)
-    t3, c0 = bits.Add64(u3, t3, c0)
-    t4, c0 = bits.Add64(u4, t4, c0)
     c2, _ = bits.Add64(c2, 0, c0)
-    t4, c0 = bits.Add64(t5, t4, 0)
-    t5, _ = bits.Add64(u5, c2, c0)
+    t2, c0 = bits.Add64(t3, t2, 0)
+    t3, _ = bits.Add64(u3, c2, c0)
 
 }
 {
     var c0, c1, c2 uint64
-    v := x[5]
+    v := x[3]
     u0, c1 = bits.Mul64(v, x[0])
     t0, c0 = bits.Add64(c1, t0, 0)
     u1, c1 = bits.Mul64(v, x[1])
@@ -315,18 +350,12 @@ func (z *Element) Square(x *Element) *Element {
     t2, c0 = bits.Add64(c1, t2, c0)
     u3, c1 = bits.Mul64(v, x[3])
     t3, c0 = bits.Add64(c1, t3, c0)
-    u4, c1 = bits.Mul64(v, x[4])
-    t4, c0 = bits.Add64(c1, t4, c0)
-    u5, c1 = bits.Mul64(v, x[5])
-    t5, c0 = bits.Add64(c1, t5, c0)
 
     c2, _ = bits.Add64(0, 0, c0)
     t1, c0 = bits.Add64(u0, t1, 0)
     t2, c0 = bits.Add64(u1, t2, c0)
     t3, c0 = bits.Add64(u2, t3, c0)
-    t4, c0 = bits.Add64(u3, t4, c0)
-    t5, c0 = bits.Add64(u4, t5, c0)
-    c2, _ = bits.Add64(u5, c2, c0)
+    c2, _ = bits.Add64(u3, c2, c0)
 
     m := qInvNeg * t0
 
@@ -337,29 +366,21 @@ func (z *Element) Square(x *Element) *Element {
     u2, c1 = bits.Mul64(m, q2)
     t1, c0 = bits.Add64(t2, c1, c0)
     u3, c1 = bits.Mul64(m, q3)
-    t2, c0 = bits.Add64(t3, c1, c0)
-    u4, c1 = bits.Mul64(m, q4)
-    t3, c0 = bits.Add64(t4, c1, c0)
-    u5, c1 = bits.Mul64(m, q5)
 
-    t4, c0 = bits.Add64(0, c1, c0)
-    u5, _ = bits.Add64(u5, 0, c0)
+    t2, c0 = bits.Add64(0, c1, c0)
+    u3, _ = bits.Add64(u3, 0, c0)
     t0, c0 = bits.Add64(u0, t0, 0)
     t1, c0 = bits.Add64(u1, t1, c0)
     t2, c0 = bits.Add64(u2, t2, c0)
-    t3, c0 = bits.Add64(u3, t3, c0)
-    t4, c0 = bits.Add64(u4, t4, c0)
     c2, _ = bits.Add64(c2, 0, c0)
-    t4, c0 = bits.Add64(t5, t4, 0)
-    t5, _ = bits.Add64(u5, c2, c0)
+    t2, c0 = bits.Add64(t3, t2, 0)
+    t3, _ = bits.Add64(u3, c2, c0)
 
 }
 z[0] = t0
 z[1] = t1
 z[2] = t2
 z[3] = t3
-    z[4] = t4
-    z[5] = t5
 
 // if z ⩾ q → z -= q
 if !z.smallerThanModulus() {
@@ -367,9 +388,15 @@ func (z *Element) Square(x *Element) *Element {
     z[0], b = bits.Sub64(z[0], q0, 0)
     z[1], b = bits.Sub64(z[1], q1, b)
     z[2], b = bits.Sub64(z[2], q2, b)
-    z[3], b = bits.Sub64(z[3], q3, b)
-    z[4], b = bits.Sub64(z[4], q4, b)
-    z[5], _ = bits.Sub64(z[5], q5, b)
+    z[3], _ = bits.Sub64(z[3], q3, b)
 }
 return z
 }
+
+// Butterfly sets
+//
+// a = a + b (mod q)
+// b = a - b (mod q)
+func Butterfly(a, b *Element) {
+    _butterflyGeneric(a, b)
+}
diff --git a/ecc/stark-curve/fp/element_ops_amd64.go b/ecc/stark-curve/fp/vector_amd64.go
similarity index 84%
rename from ecc/stark-curve/fp/element_ops_amd64.go
rename to ecc/stark-curve/fp/vector_amd64.go
index 2ab1a98399..75719deedb 100644
--- a/ecc/stark-curve/fp/element_ops_amd64.go
+++ b/ecc/stark-curve/fp/vector_amd64.go
@@ -1,5 +1,4 @@
 //go:build !purego
-// +build !purego
 
 // Copyright 2020 ConsenSys Software Inc.
 //
@@ -19,32 +18,6 @@
 
 package fp
 
-//go:noescape
-func MulBy3(x *Element)
-
-//go:noescape
-func MulBy5(x *Element)
-
-//go:noescape
-func MulBy13(x *Element)
-
-//go:noescape
-func mul(res, x, y *Element)
-
-//go:noescape
-func fromMont(res *Element)
-
-//go:noescape
-func reduce(res *Element)
-
-// Butterfly sets
-//
-// a = a + b (mod q)
-// b = a - b (mod q)
-//
-//go:noescape
-func Butterfly(a, b *Element)
-
 // Add adds two vectors element-wise and stores the result in self.
 // It panics if the vectors don't have the same length.
 func (vector *Vector) Add(a, b Vector) {
@@ -185,24 +158,3 @@ var (
 
 //go:noescape
 func mulVec(res, a, b *Element, n uint64, qInvNeg uint64)
-
-// Mul z = x * y (mod q)
-//
-// x and y must be less than q
-func (z *Element) Mul(x, y *Element) *Element {
-
-    // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS"
-    // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521
-
-    mul(z, x, y)
-    return z
-}
-
-// Square z = x * x (mod q)
-//
-// x must be less than q
-func (z *Element) Square(x *Element) *Element {
-    // see Mul for doc.
-    mul(z, x, x)
-    return z
-}
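Editor's note: vector_amd64.go keeps the vectorized mulVec entry point, whose extra n and qInvNeg arguments let the assembly kernel loop and reduce without reloading constants. A plausible dispatch wrapper around it — the supportAvx512 flag and the fallback path are assumptions inferred from the surrounding code, not shown in this hunk:

// Mul multiplies two vectors element-wise and stores the result in self.
// Sketch of the dispatch this file presumably performs.
func (vector *Vector) Mul(a, b Vector) {
    if len(a) != len(b) || len(a) != len(*vector) {
        panic("vector.Mul: vectors don't have the same length")
    }
    n := uint64(len(a))
    if n == 0 {
        return
    }
    if !supportAvx512 {
        // no AVX-512: fall back to the generic element loop
        mulVecGeneric(*vector, a, b)
        return
    }
    mulVec(&(*vector)[0], &a[0], &b[0], n, qInvNeg)
}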
diff --git a/ecc/stark-curve/fp/vector_purego.go b/ecc/stark-curve/fp/vector_purego.go
new file mode 100644
index 0000000000..c6d37d76f4
--- /dev/null
+++ b/ecc/stark-curve/fp/vector_purego.go
@@ -0,0 +1,56 @@
+//go:build purego || arm64
+
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package fp
+
+// Add adds two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Add(a, b Vector) {
+    addVecGeneric(*vector, a, b)
+}
+
+// Sub subtracts two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Sub(a, b Vector) {
+    subVecGeneric(*vector, a, b)
+}
+
+// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) ScalarMul(a Vector, b *Element) {
+    scalarMulVecGeneric(*vector, a, b)
+}
+
+// Sum computes the sum of all elements in the vector.
+func (vector *Vector) Sum() (res Element) {
+    sumVecGeneric(&res, *vector)
+    return
+}
+
+// InnerProduct computes the inner product of two vectors.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) InnerProduct(other Vector) (res Element) {
+    innerProductVecGeneric(&res, *vector, other)
+    return
+}
+
+// Mul multiplies two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Mul(a, b Vector) {
+    mulVecGeneric(*vector, a, b)
+}
diff --git a/ecc/stark-curve/fr/arith.go b/ecc/stark-curve/fr/arith.go
index 7cfd55da19..83c9fd9ef9 100644
--- a/ecc/stark-curve/fr/arith.go
+++ b/ecc/stark-curve/fr/arith.go
@@ -58,16 +58,3 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) {
     hi, _ = bits.Add64(hi, e, carry)
     return
 }
-func max(a int, b int) int {
-    if a > b {
-        return a
-    }
-    return b
-}
-
-func min(a int, b int) int {
-    if a < b {
-        return a
-    }
-    return b
-}
diff --git a/ecc/stark-curve/fr/asm_adx.go b/ecc/stark-curve/fr/asm_adx.go
index da061913ba..9273ea23ab 100644
--- a/ecc/stark-curve/fr/asm_adx.go
+++ b/ecc/stark-curve/fr/asm_adx.go
@@ -1,5 +1,4 @@
 //go:build !noadx
-// +build !noadx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/stark-curve/fr/asm_avx.go b/ecc/stark-curve/fr/asm_avx.go
index 955f559799..1cc06c6e8d 100644
--- a/ecc/stark-curve/fr/asm_avx.go
+++ b/ecc/stark-curve/fr/asm_avx.go
@@ -1,5 +1,4 @@
 //go:build !noavx
-// +build !noavx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/stark-curve/fr/asm_noadx.go b/ecc/stark-curve/fr/asm_noadx.go
index 7f52ffa197..b784a24247 100644
--- a/ecc/stark-curve/fr/asm_noadx.go
+++ b/ecc/stark-curve/fr/asm_noadx.go
@@ -1,5 +1,4 @@
 //go:build noadx
-// +build noadx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
diff --git a/ecc/stark-curve/fr/asm_noavx.go b/ecc/stark-curve/fr/asm_noavx.go
index e5a5b1f2cc..66bfc00772 100644
--- a/ecc/stark-curve/fr/asm_noavx.go
+++ b/ecc/stark-curve/fr/asm_noavx.go
@@ -1,5 +1,4 @@
 //go:build noavx
-// +build noavx
 
 // Copyright 2020 ConsenSys Software Inc.
 //
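Editor's note: the noadx/noavx build tags above pair a compile-time opt-out with runtime CPU-feature detection. A sketch of the pattern — the golang.org/x/sys/cpu flags are an assumption about what sits behind these asm_*.go files, not content of this patch:

//go:build !noadx

package fr

import "golang.org/x/sys/cpu"

// supportAdx gates the ADX/BMI2 assembly path at runtime; building with
// -tags=noadx swaps in a sibling file that pins it to false.
var supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2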
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/stark-curve/fr/element_amd64.s b/ecc/stark-curve/fr/element_amd64.s new file mode 100644 index 0000000000..fb00194d7e --- /dev/null +++ b/ecc/stark-curve/fr/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 9425145785761608449 +#include "../../../field/asm/element_4w_amd64.s" + diff --git a/ecc/stark-curve/fr/element_arm64.go b/ecc/stark-curve/fr/element_arm64.go new file mode 100644 index 0000000000..474236e22f --- /dev/null +++ b/ecc/stark-curve/fr/element_arm64.go @@ -0,0 +1,78 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
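The pattern in element_amd64.go/.s above, a one-line assembly stub that #include's the shared kernel plus //go:noescape forward declarations on the Go side, is what lets every 4-word field share a single generated .s file. A minimal sketch of the shape, with an illustrative function name:

    //go:build !purego

    package fr

    // add is implemented in the included assembly. //go:noescape
    // promises the compiler that the pointer arguments do not escape,
    // so calls avoid heap-allocating the operands.
    //
    //go:noescape
    func add(res, x, y *Element)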
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 13231284915721003215, + 9638582829363634368, + 117, + 576460752303416433, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} diff --git a/ecc/stark-curve/fr/element_arm64.s b/ecc/stark-curve/fr/element_arm64.s new file mode 100644 index 0000000000..3cd8aaa667 --- /dev/null +++ b/ecc/stark-curve/fr/element_arm64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 18027907654287790676 +#include "../../../field/asm/element_4w_arm64.s" + diff --git a/ecc/stark-curve/fr/element_ops_amd64.go b/ecc/stark-curve/fr/element_ops_amd64.go deleted file mode 100644 index b653e80069..0000000000 --- a/ecc/stark-curve/fr/element_ops_amd64.go +++ /dev/null @@ -1,208 +0,0 @@ -//go:build !purego -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -// -//go:noescape -func Butterfly(a, b *Element) - -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
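MulBy3 and MulBy5 above deliberately avoid a full Montgomery multiplication: 3x is computed as 2x + x and 5x as 2(2x) + x, that is, one or two modular doublings plus an addition. MulBy13 instead multiplies by a precomputed constant, which, since Mul operates on Montgomery representations, is presumably 13·R mod q. A sketch of the doubling chain for 3x, mirroring the generated code:

    func mulBy3Sketch(x *Element) {
        saved := *x      // keep the original x
        x.Double(x)      // x = 2x (mod q)
        x.Add(x, &saved) // x = 2x + x = 3x (mod q)
    }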
-func (vector *Vector) Add(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Add: vectors don't have the same length") - } - n := uint64(len(a)) - addVec(&(*vector)[0], &a[0], &b[0], n) -} - -//go:noescape -func addVec(res, a, b *Element, n uint64) - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Sub: vectors don't have the same length") - } - subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) -} - -//go:noescape -func subVec(res, a, b *Element, n uint64) - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - if len(a) != len(*vector) { - panic("vector.ScalarMul: vectors don't have the same length") - } - const maxN = (1 << 32) - 1 - if !supportAvx512 || uint64(len(a)) >= maxN { - // call scalarMulVecGeneric - scalarMulVecGeneric(*vector, a, b) - return - } - n := uint64(len(a)) - if n == 0 { - return - } - // the code for scalarMul is identical to mulVec; and it expects at least - // 2 elements in the vector to fill the Z registers - var bb [2]Element - bb[0] = *b - bb[1] = *b - const blockSize = 16 - scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) - if n%blockSize != 0 { - // call scalarMulVecGeneric on the rest - start := n - n%blockSize - scalarMulVecGeneric((*vector)[start:], a[start:], b) - } -} - -//go:noescape -func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - n := uint64(len(*vector)) - if n == 0 { - return - } - const minN = 16 * 7 // AVX512 slower than generic for small n - const maxN = (1 << 32) - 1 - if !supportAvx512 || n <= minN || n >= maxN { - // call sumVecGeneric - sumVecGeneric(&res, *vector) - return - } - sumVec(&res, &(*vector)[0], uint64(len(*vector))) - return -} - -//go:noescape -func sumVec(res *Element, a *Element, n uint64) - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - n := uint64(len(*vector)) - if n == 0 { - return - } - if n != uint64(len(other)) { - panic("vector.InnerProduct: vectors don't have the same length") - } - const maxN = (1 << 32) - 1 - if !supportAvx512 || n >= maxN { - // call innerProductVecGeneric - // note; we could split the vector into smaller chunks and call innerProductVec - innerProductVecGeneric(&res, *vector, other) - return - } - innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) - - return -} - -//go:noescape -func innerProdVec(res *uint64, a, b *Element, n uint64) - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
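The bb [2]Element duplication in the deleted ScalarMul above (it reappears verbatim in the new vector_amd64.go) follows from its own comment: the scalar path reuses the mulVec kernel, which loads operands two 4-limb elements, one 512-bit Z register, at a time, so a lone Element would not be safe to read. A condensed sketch of the guard, under that reading:

    var bb [2]Element
    bb[0] = *b // first lane the kernel loads
    bb[1] = *b // second lane: same scalar, keeps the full 512-bit load in owned memory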
-func (vector *Vector) Mul(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Mul: vectors don't have the same length") - } - n := uint64(len(a)) - if n == 0 { - return - } - const maxN = (1 << 32) - 1 - if !supportAvx512 || n >= maxN { - // call mulVecGeneric - mulVecGeneric(*vector, a, b) - return - } - - const blockSize = 16 - mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) - if n%blockSize != 0 { - // call mulVecGeneric on the rest - start := n - n%blockSize - mulVecGeneric((*vector)[start:], a[start:], b[start:]) - } - -} - -// Patterns use for transposing the vectors in mulVec -var ( - pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} - pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} - pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} - pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} -) - -//go:noescape -func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for doc. - mul(z, x, x) - return z -} diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s deleted file mode 100644 index 6c42136a7a..0000000000 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. -// We include the hash to force the Go compiler to recompile: 9425145785761608449 -#include "../../../field/asm/element_4w_amd64.s" - diff --git a/ecc/stark-curve/fr/element_ops_arm64.go b/ecc/stark-curve/fr/element_ops_arm64.go deleted file mode 100644 index 6759e524eb..0000000000 --- a/ecc/stark-curve/fr/element_ops_arm64.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !purego -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -//go:noescape -func Butterfly(a, b *Element) - -//go:noescape -func mul(res, x, y *Element) - -func (z *Element) Mul(x, y *Element) *Element { - mul(z, x, y) - return z -} diff --git a/ecc/stark-curve/fr/element_ops_arm64.s b/ecc/stark-curve/fr/element_ops_arm64.s deleted file mode 100644 index 6ba54c61aa..0000000000 --- a/ecc/stark-curve/fr/element_ops_arm64.s +++ /dev/null @@ -1,6 +0,0 @@ -// +build !purego - -// Code generated by gnark-crypto/generator. DO NOT EDIT. 
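The amd64 Vector.Mul deleted above moves essentially unchanged into vector_amd64.go; its dispatch shape recurs throughout the patch: the assembly kernel consumes whole 16-element blocks and the generic code handles the tail. A condensed sketch, where mulVecKernel stands in for the real mulVec symbol:

    func mulBlocks(res, a, b []Element) {
        const blockSize = 16
        n := uint64(len(a))
        if n == 0 {
            return
        }
        mulVecKernel(&res[0], &a[0], &b[0], n/blockSize) // whole blocks, vectorized
        if rem := n % blockSize; rem != 0 {
            start := n - rem
            mulVecGeneric(res[start:], a[start:], b[start:]) // scalar tail
        }
    }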
-// We include the hash to force the Go compiler to recompile: 18027907654287790676 -#include "../../../field/asm/element_4w_arm64.s" - diff --git a/ecc/stark-curve/fr/element_ops_purego.go b/ecc/stark-curve/fr/element_ops_purego.go deleted file mode 100644 index ac4346db5b..0000000000 --- a/ecc/stark-curve/fr/element_ops_purego.go +++ /dev/null @@ -1,260 +0,0 @@ -//go:build !amd64 || purego -// +build !amd64 purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -import "math/bits" - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 13231284915721003215, - 9638582829363634368, - 117, - 576460752303416433, - } - x.Mul(x, &y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} - -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - - var t0, t1, t2, t3 uint64 - var u0, u1, u2, u3 uint64 - { - var c0, c1, c2 uint64 - v := x[0] - u0, t0 = bits.Mul64(v, x[0]) - u1, t1 = bits.Mul64(v, x[1]) - u2, t2 = bits.Mul64(v, x[2]) - u3, t3 = bits.Mul64(v, x[3]) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, 0, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[1] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[2] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - { - var c0, c1, c2 uint64 - v := x[3] - u0, c1 = bits.Mul64(v, x[0]) - t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) - t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) - t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) - t3, c0 = bits.Add64(c1, t3, c0) - - c2, _ = bits.Add64(0, 0, c0) - t1, c0 = bits.Add64(u0, t1, 0) - t2, c0 = bits.Add64(u1, t2, c0) - t3, c0 = bits.Add64(u2, t3, c0) - c2, _ = bits.Add64(u3, c2, c0) - - m := qInvNeg * t0 - - u0, c1 = 
bits.Mul64(m, q0) - _, c0 = bits.Add64(t0, c1, 0) - u1, c1 = bits.Mul64(m, q1) - t0, c0 = bits.Add64(t1, c1, c0) - u2, c1 = bits.Mul64(m, q2) - t1, c0 = bits.Add64(t2, c1, c0) - u3, c1 = bits.Mul64(m, q3) - - t2, c0 = bits.Add64(0, c1, c0) - u3, _ = bits.Add64(u3, 0, c0) - t0, c0 = bits.Add64(u0, t0, 0) - t1, c0 = bits.Add64(u1, t1, c0) - t2, c0 = bits.Add64(u2, t2, c0) - c2, _ = bits.Add64(c2, 0, c0) - t2, c0 = bits.Add64(t3, t2, 0) - t3, _ = bits.Add64(u3, c2, c0) - - } - z[0] = t0 - z[1] = t1 - z[2] = t2 - z[3] = t3 - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], q0, 0) - z[1], b = bits.Sub64(z[1], q1, b) - z[2], b = bits.Sub64(z[2], q2, b) - z[3], _ = bits.Sub64(z[3], q3, b) - } - return z -} diff --git a/ecc/bw6-761/fr/element_ops_purego.go b/ecc/stark-curve/fr/element_purego.go similarity index 59% rename from ecc/bw6-761/fr/element_ops_purego.go rename to ecc/stark-curve/fr/element_purego.go index df3ef00ace..7aea910af6 100644 --- a/ecc/bw6-761/fr/element_ops_purego.go +++ b/ecc/stark-curve/fr/element_purego.go @@ -1,5 +1,4 @@ -//go:build !amd64 || purego -// +build !amd64 purego +//go:build purego // Copyright 2020 ConsenSys Software Inc. // @@ -36,12 +35,10 @@ func MulBy5(x *Element) { // MulBy13 x *= 13 (mod q) func MulBy13(x *Element) { var y = Element{ - 1176283927673829444, - 14130787773971430395, - 11354866436980285261, - 15740727779991009548, - 14951814113394531041, - 33013799364667434, + 13231284915721003215, + 9638582829363634368, + 117, + 576460752303416433, } x.Mul(x, &y) } @@ -54,29 +51,27 @@ func reduce(z *Element) { _reduceGeneric(z) } -// Square z = x * x (mod q) +// Mul z = x * y (mod q) // -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { - var t0, t1, t2, t3, t4, t5 uint64 - var u0, u1, u2, u3, u4, u5 uint64 + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 { var c0, c1, c2 uint64 v := x[0] - u0, t0 = bits.Mul64(v, x[0]) - u1, t1 = bits.Mul64(v, x[1]) - u2, t2 = bits.Mul64(v, x[2]) - u3, t3 = bits.Mul64(v, x[3]) - u4, t4 = bits.Mul64(v, x[4]) - u5, t5 = bits.Mul64(v, x[5]) + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, 0, c0) + c2, _ = bits.Add64(u3, 0, c0) m := qInvNeg * t0 @@ -87,46 +82,34 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } { var c0, c1, c2 uint64 v := x[1] - u0, c1 = bits.Mul64(v, x[0]) + u0, c1 = bits.Mul64(v, y[0]) t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) + u1, c1 = bits.Mul64(v, y[1]) t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) + u2, c1 = bits.Mul64(v, y[2]) t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) + u3, c1 = bits.Mul64(v, y[3]) t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, x[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, x[5]) - t5, c0 = bits.Add64(c1, t5, c0) c2, _ = bits.Add64(0, 0, c0) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) + c2, _ = bits.Add64(u3, c2, c0) m := qInvNeg * t0 @@ -137,46 +120,34 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } { var c0, c1, c2 uint64 v := x[2] - u0, c1 = bits.Mul64(v, x[0]) + u0, c1 = bits.Mul64(v, y[0]) t0, c0 = bits.Add64(c1, t0, 0) - u1, c1 = bits.Mul64(v, x[1]) + u1, c1 = bits.Mul64(v, y[1]) t1, c0 = bits.Add64(c1, t1, c0) - u2, c1 = bits.Mul64(v, x[2]) + u2, c1 = bits.Mul64(v, y[2]) t2, c0 = bits.Add64(c1, t2, c0) - u3, c1 = bits.Mul64(v, x[3]) + u3, c1 = bits.Mul64(v, y[3]) t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, x[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, x[5]) - t5, c0 = bits.Add64(c1, t5, c0) c2, _ = bits.Add64(0, 
0, c0) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) + c2, _ = bits.Add64(u3, c2, c0) m := qInvNeg * t0 @@ -187,26 +158,114 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } { var c0, c1, c2 uint64 v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] u0, c1 = bits.Mul64(v, x[0]) t0, c0 = bits.Add64(c1, t0, 0) u1, c1 = bits.Mul64(v, x[1]) @@ -215,18 +274,12 @@ func (z *Element) Square(x *Element) *Element { t2, c0 = bits.Add64(c1, t2, c0) u3, c1 = bits.Mul64(v, x[3]) t3, c0 = bits.Add64(c1, t3, 
c0) - u4, c1 = bits.Mul64(v, x[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, x[5]) - t5, c0 = bits.Add64(c1, t5, c0) c2, _ = bits.Add64(0, 0, c0) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) + c2, _ = bits.Add64(u3, c2, c0) m := qInvNeg * t0 @@ -237,26 +290,20 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } { var c0, c1, c2 uint64 - v := x[4] + v := x[2] u0, c1 = bits.Mul64(v, x[0]) t0, c0 = bits.Add64(c1, t0, 0) u1, c1 = bits.Mul64(v, x[1]) @@ -265,18 +312,12 @@ func (z *Element) Square(x *Element) *Element { t2, c0 = bits.Add64(c1, t2, c0) u3, c1 = bits.Mul64(v, x[3]) t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, x[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, x[5]) - t5, c0 = bits.Add64(c1, t5, c0) c2, _ = bits.Add64(0, 0, c0) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) + c2, _ = bits.Add64(u3, c2, c0) m := qInvNeg * t0 @@ -287,26 +328,20 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } { var c0, c1, c2 uint64 - v := x[5] + v := x[3] u0, c1 = bits.Mul64(v, x[0]) t0, c0 = bits.Add64(c1, t0, 0) u1, c1 = bits.Mul64(v, x[1]) @@ -315,18 +350,12 @@ func (z *Element) Square(x *Element) *Element { t2, c0 = bits.Add64(c1, t2, c0) u3, c1 = bits.Mul64(v, x[3]) t3, c0 = bits.Add64(c1, t3, c0) - u4, c1 = bits.Mul64(v, x[4]) - t4, c0 = bits.Add64(c1, t4, c0) - u5, c1 = bits.Mul64(v, x[5]) - t5, c0 = bits.Add64(c1, t5, c0) c2, _ = bits.Add64(0, 0, c0) t1, c0 = bits.Add64(u0, t1, 0) t2, c0 = bits.Add64(u1, t2, c0) t3, c0 = bits.Add64(u2, t3, c0) - t4, c0 = bits.Add64(u3, t4, c0) - t5, c0 = bits.Add64(u4, t5, c0) - c2, _ = bits.Add64(u5, c2, c0) + c2, _ = bits.Add64(u3, c2, c0) m := qInvNeg * t0 @@ -337,29 +366,21 @@ func (z *Element) Square(x *Element) *Element { u2, c1 = bits.Mul64(m, q2) t1, c0 = bits.Add64(t2, c1, c0) u3, c1 = bits.Mul64(m, q3) - t2, c0 = bits.Add64(t3, c1, c0) - u4, c1 = bits.Mul64(m, q4) - t3, c0 = bits.Add64(t4, c1, c0) - 
u5, c1 = bits.Mul64(m, q5) - t4, c0 = bits.Add64(0, c1, c0) - u5, _ = bits.Add64(u5, 0, c0) + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) t0, c0 = bits.Add64(u0, t0, 0) t1, c0 = bits.Add64(u1, t1, c0) t2, c0 = bits.Add64(u2, t2, c0) - t3, c0 = bits.Add64(u3, t3, c0) - t4, c0 = bits.Add64(u4, t4, c0) c2, _ = bits.Add64(c2, 0, c0) - t4, c0 = bits.Add64(t5, t4, 0) - t5, _ = bits.Add64(u5, c2, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) } z[0] = t0 z[1] = t1 z[2] = t2 z[3] = t3 - z[4] = t4 - z[5] = t5 // if z ⩾ q → z -= q if !z.smallerThanModulus() { @@ -367,9 +388,15 @@ func (z *Element) Square(x *Element) *Element { z[0], b = bits.Sub64(z[0], q0, 0) z[1], b = bits.Sub64(z[1], q1, b) z[2], b = bits.Sub64(z[2], q2, b) - z[3], b = bits.Sub64(z[3], q3, b) - z[4], b = bits.Sub64(z[4], q4, b) - z[5], _ = bits.Sub64(z[5], q5, b) + z[3], _ = bits.Sub64(z[3], q3, b) } return z } + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/ecc/stark-curve/fr/vector_amd64.go b/ecc/stark-curve/fr/vector_amd64.go new file mode 100644 index 0000000000..0164ecb382 --- /dev/null +++ b/ecc/stark-curve/fr/vector_amd64.go @@ -0,0 +1,160 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) +} + +//go:noescape +func addVec(res, a, b *Element, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
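Every round of the 4-word Mul and Square above rests on one identity: with qInvNeg = -q⁻¹ mod 2⁶⁴, the multiplier m = t0·qInvNeg makes t0 + m·q vanish modulo 2⁶⁴, which is why the low word is simply discarded (the "_, c0 = bits.Add64(t0, c1, 0)" line). A self-contained check of that cancellation on the low limb (uses math/bits):

    func lowLimbCancels(t0, q0, qInvNeg uint64) bool {
        m := t0 * qInvNeg          // reduced mod 2^64 by uint64 wrap-around
        _, lo := bits.Mul64(m, q0) // low 64 bits of m·q0
        return lo+t0 == 0          // wraps to zero: t0 + m·q ≡ 0 (mod 2^64)
    }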
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]Element + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a, b *Element, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } + +} + +// Patterns use for transposing the vectors in mulVec +var ( + pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} + pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} + pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} + pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} +) + +//go:noescape +func mulVec(res, a, b *Element, n uint64, qInvNeg uint64) diff --git a/ecc/stark-curve/fr/vector_purego.go b/ecc/stark-curve/fr/vector_purego.go new file mode 100644 index 0000000000..d09c259806 --- /dev/null +++ b/ecc/stark-curve/fr/vector_purego.go @@ -0,0 +1,56 @@ +//go:build purego || arm64 + +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
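This purego vector file also builds on arm64 (tag "purego || arm64"), since only amd64 gets vector assembly in this patch; every method funnels into the *Generic helpers, whose bodies the diff does not show. A representative sketch of what addVecGeneric plausibly does:

    func addVecGenericSketch(res, a, b []Element) {
        if len(a) != len(b) || len(a) != len(res) {
            panic("vector.Add: vectors don't have the same length")
        }
        for i := range a {
            res[i].Add(&a[i], &b[i]) // element-wise modular addition
        }
    }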
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index ead32ae665..f3b23f31d6 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -108,7 +108,6 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { } f.generateButterfly() - f.generateMul() return nil diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index f9617f6e1d..1efe7d8b32 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -52,9 +52,6 @@ type FieldConfig struct { QInverse []uint64 QMinusOneHalvedP []uint64 // ((q-1) / 2 ) + 1 Mu uint64 // mu = 2^288 / q for 4.5 word barrett reduction - ASM bool - ASMVector bool - ASMArm bool RSquare []uint64 One, Thirteen []uint64 LegendreExponent string // big.Int to base16 string @@ -75,6 +72,12 @@ type FieldConfig struct { SqrtSMinusOneOver2Data *addchain.AddChainData SqrtQ3Mod4ExponentData *addchain.AddChainData UseAddChain bool + + // asm code generation + GenerateOpsAMD64 bool + GenerateOpsARM64 bool + GenerateVectorOpsAMD64 bool + GenerateVectorOpsARM64 bool } // NewFieldConfig returns a data structure with needed information to generate apis for field element @@ -262,9 +265,10 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // note: to simplify output files generated, we generated ASM code only for // moduli that meet the condition F.NoCarry // asm code generation for moduli with more than 6 words can be optimized further - F.ASM = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 - F.ASMVector = F.ASM && F.NbWords == 4 && F.NbBits > 225 - F.ASMArm = F.ASMVector || (F.NbWords == 6) + F.GenerateOpsAMD64 = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 + F.GenerateVectorOpsAMD64 = F.GenerateOpsAMD64 && F.NbWords == 4 && F.NbBits > 225 + F.GenerateOpsARM64 = F.GenerateOpsAMD64 && (F.NbWords == 6 || F.NbWords == 4) + F.GenerateVectorOpsARM64 = false // setting Mu 2^288 / q if F.NbWords == 4 { diff --git a/field/generator/generator.go b/field/generator/generator.go index 29ed67accd..d21c6632e1 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -2,7 +2,7 @@ package generator import ( "fmt" - "io" + "hash/fnv" "os" "os/exec" "path/filepath" @@ -15,6 +15,7 @@ import ( "github.com/consensys/gnark-crypto/field/generator/config" "github.com/consensys/gnark-crypto/field/generator/internal/addchain" "github.com/consensys/gnark-crypto/field/generator/internal/templates/element" + "golang.org/x/sync/errgroup" ) // GenerateFF will generate go (and .s) files in outputDir for modulus (in base 10) @@ -46,38 +47,6 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude element.Test, element.InverseTests, } - // output files - eName := strings.ToLower(F.ElementName) - - pathSrc := filepath.Join(outputDir, eName+".go") - pathSrcVector := filepath.Join(outputDir, "vector.go") - pathSrcFixedExp := filepath.Join(outputDir, eName+"_exp.go") - pathSrcArith := filepath.Join(outputDir, "arith.go") - pathTest := filepath.Join(outputDir, eName+"_test.go") - pathTestVector := filepath.Join(outputDir, "vector_test.go") - - // remove old format generated files - oldFiles := []string{"_mul.go", "_mul_amd64.go", - "_square.go", "_square_amd64.go", "_ops_decl.go", "_square_amd64.s", - "_mul_amd64.s", - "_mul_arm64.s", - "_mul_arm64.go", - "_ops_amd64.s", - "_ops_arm64.s", - 
"_ops_noasm.go", - "_mul_adx_amd64.s", - "_ops_amd64.go", - "_fuzz.go", - } - - for _, of := range oldFiles { - _ = os.Remove(filepath.Join(outputDir, eName+of)) - } - _ = os.Remove(filepath.Join(outputDir, "asm.go")) - _ = os.Remove(filepath.Join(outputDir, "asm_noadx.go")) - _ = os.Remove(filepath.Join(outputDir, "avx.go")) - _ = os.Remove(filepath.Join(outputDir, "noavx.go")) - funcs := template.FuncMap{} if F.UseAddChain { for _, f := range addchain.Functions { @@ -85,241 +54,178 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude } } + os.Remove(filepath.Join(outputDir, "vector_arm64.go")) + os.Remove(filepath.Join(outputDir, "exp.go")) + funcs["shorten"] = shorten funcs["ltu64"] = func(a, b uint64) bool { return a < b } - bavardOpts := []func(*bavard.Bavard) error{ - bavard.Apache2("ConsenSys Software Inc.", 2020), - bavard.Package(F.PackageName), - bavard.GeneratedBy("consensys/gnark-crypto"), - bavard.Funcs(funcs), - } + generate := func(suffix string, templates []string, opts ...option) func() error { + opt := generateOptions(opts...) + if opt.skip { + return func() error { return nil } + } + return func() error { + bavardOpts := []func(*bavard.Bavard) error{ + bavard.Apache2("ConsenSys Software Inc.", 2020), + bavard.GeneratedBy("consensys/gnark-crypto"), + bavard.Funcs(funcs), + } + if !strings.HasSuffix(suffix, ".s") { + bavardOpts = append(bavardOpts, bavard.Package(F.PackageName)) + } + if opt.buildTag != "" { + bavardOpts = append(bavardOpts, bavard.BuildTag(opt.buildTag)) + } + if suffix == ".go" { + suffix = filepath.Join(outputDir, suffix) + } else { + suffix = filepath.Join(outputDir, suffix) + } - // generate source file - if err := bavard.GenerateFromString(pathSrc, sourceFiles, F, bavardOpts...); err != nil { - return err - } + tmplData := any(F) + if opt.tmplData != nil { + tmplData = opt.tmplData + } - // generate vector - if err := bavard.GenerateFromString(pathSrcVector, []string{element.Vector}, F, bavardOpts...); err != nil { - return err + return bavard.GenerateFromString(suffix, templates, tmplData, bavardOpts...) + } } - // generate arithmetics source file - if err := bavard.GenerateFromString(pathSrcArith, []string{element.Arith}, F, bavardOpts...); err != nil { - return err + // generate asm files; + // couple of cases; + // 1. we generate arm64 and amd64 + // 2. we generate only amd64 + // 3. we generate only purego + + // sanity check + if (F.GenerateOpsARM64 && !F.GenerateOpsAMD64) || + (F.GenerateVectorOpsAMD64 && !F.GenerateOpsAMD64) || + (F.GenerateVectorOpsARM64 && !F.GenerateOpsARM64) { + panic("not implemented.") } - // generate fixed exp source file - if F.UseAddChain { - if err := bavard.GenerateFromString(pathSrcFixedExp, []string{element.FixedExp}, F, bavardOpts...); err != nil { + // get hash of the common asm files to force compiler to recompile in case of changes. 
+ var amd64d, arm64d ASMWrapperData + var err error + + if F.GenerateOpsAMD64 { + amd64d, err = hashAndInclude(asmDirBuildPath, asmDirIncludePath, amd64.ElementASMFileName, F.NbWords) + if err != nil { return err } } - // generate test file - if err := bavard.GenerateFromString(pathTest, testFiles, F, bavardOpts...); err != nil { - return err + if F.GenerateOpsARM64 { + arm64d, err = hashAndInclude(asmDirBuildPath, asmDirIncludePath, arm64.ElementASMFileName, F.NbWords) + if err != nil { + return err + } } - if err := bavard.GenerateFromString(pathTestVector, []string{element.TestVector}, F, bavardOpts...); err != nil { - return err + // purego files have no build tags if we don't generate asm + pureGoBuildTag := "purego" + if !F.GenerateOpsAMD64 && !F.GenerateOpsARM64 { + pureGoBuildTag = "" + } else if !F.GenerateOpsARM64 { + pureGoBuildTag = "purego || arm64" + } + pureGoVectorBuildTag := "purego" + if !F.GenerateVectorOpsAMD64 && !F.GenerateVectorOpsARM64 { + pureGoVectorBuildTag = "" + } else if !F.GenerateVectorOpsARM64 { + pureGoVectorBuildTag = "purego || arm64" } - // if we generate assembly code - if F.ASM { - // generate ops.s - { - pathSrc := filepath.Join(outputDir, eName+"_ops_amd64.s") - fmt.Println("generating", pathSrc) - f, err := os.Create(pathSrc) - if err != nil { - return err - } - - _, _ = io.WriteString(f, "// +build !purego\n") - - if err := amd64.GenerateFieldWrapper(f, F, asmDirBuildPath, asmDirIncludePath); err != nil { - _ = f.Close() - return err - } - _ = f.Close() + var g errgroup.Group - // run asmfmt - // run go fmt on whole directory - cmd := exec.Command("asmfmt", "-w", pathSrc) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return err - } - } + g.Go(generate("element.go", sourceFiles)) + g.Go(generate("doc.go", []string{element.Doc})) + g.Go(generate("vector.go", []string{element.Vector})) + g.Go(generate("arith.go", []string{element.Arith})) + g.Go(generate("element_test.go", testFiles)) + g.Go(generate("vector_test.go", []string{element.TestVector})) - } + g.Go(generate("element_amd64.s", []string{element.IncludeASM}, Only(F.GenerateOpsAMD64), WithBuildTag("!purego"), WithData(amd64d))) + g.Go(generate("element_arm64.s", []string{element.IncludeASM}, Only(F.GenerateOpsARM64), WithBuildTag("!purego"), WithData(arm64d))) - if F.ASMArm { - // generate ops.s - { - pathSrc := filepath.Join(outputDir, eName+"_ops_arm64.s") - fmt.Println("generating", pathSrc) - f, err := os.Create(pathSrc) - if err != nil { - return err - } + g.Go(generate("element_amd64.go", []string{element.OpsAMD64, element.MulDoc}, Only(F.GenerateOpsAMD64), WithBuildTag("!purego"))) + g.Go(generate("element_arm64.go", []string{element.OpsARM64, element.MulNoCarry, element.Reduce}, Only(F.GenerateOpsARM64), WithBuildTag("!purego"))) - _, _ = io.WriteString(f, "// +build !purego\n") + g.Go(generate("element_purego.go", []string{element.OpsNoAsm, element.MulCIOS, element.MulNoCarry, element.Reduce, element.MulDoc}, WithBuildTag(pureGoBuildTag))) - if err := arm64.GenerateFieldWrapper(f, F, asmDirBuildPath, asmDirIncludePath); err != nil { - _ = f.Close() - return err - } - _ = f.Close() + g.Go(generate("vector_amd64.go", []string{element.VectorOpsAmd64}, Only(F.GenerateVectorOpsAMD64), WithBuildTag("!purego"))) + g.Go(generate("vector_arm64.go", []string{element.VectorOpsArm64}, Only(F.GenerateVectorOpsARM64), WithBuildTag("!purego"))) - // run asmfmt - // run go fmt on whole directory - cmd := exec.Command("asmfmt", "-w", pathSrc) - cmd.Stdout = 
os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return err - } - } + g.Go(generate("vector_purego.go", []string{element.VectorOpsPureGo}, WithBuildTag(pureGoVectorBuildTag))) - } + g.Go(generate("asm_adx.go", []string{element.Asm}, Only(F.GenerateOpsAMD64), WithBuildTag("!noadx"))) + g.Go(generate("asm_noadx.go", []string{element.AsmNoAdx}, Only(F.GenerateOpsAMD64), WithBuildTag("noadx"))) + g.Go(generate("asm_avx.go", []string{element.Avx}, Only(F.GenerateVectorOpsAMD64), WithBuildTag("!noavx"))) + g.Go(generate("asm_noavx.go", []string{element.NoAvx}, Only(F.GenerateVectorOpsAMD64), WithBuildTag("noavx"))) - if F.ASM { - // generate ops_amd64.go - src := []string{ - element.MulDoc, - element.OpsAMD64, - } - pathSrc := filepath.Join(outputDir, eName+"_ops_amd64.go") - bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) - copy(bavardOptsCpy, bavardOpts) - if F.ASM { - bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!purego")) - } - if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil { - return err - } + if F.UseAddChain { + g.Go(generate("element_exp.go", []string{element.FixedExp})) } - if F.ASMArm { - // generate ops_arm64.go - src := []string{ - element.MulDoc, - element.OpsARM64, - element.MulNoCarry, - element.Reduce, - } - pathSrc := filepath.Join(outputDir, eName+"_ops_arm64.go") - bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) - copy(bavardOptsCpy, bavardOpts) - if F.ASM { - bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!purego")) - } - if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil { - return err - } + if err := g.Wait(); err != nil { + return err } { - // generate ops.go - src := []string{ - element.OpsNoAsm, - element.MulCIOS, - element.MulNoCarry, - element.Reduce, - element.MulDoc, - } - pathSrc := filepath.Join(outputDir, eName+"_ops_purego.go") - bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) - copy(bavardOptsCpy, bavardOpts) - if F.ASM { - bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!amd64 purego")) - } - if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil { + // run go fmt on whole directory + cmd := exec.Command("gofmt", "-s", "-w", outputDir) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { return err } } - { - // generate doc.go - src := []string{ - element.Doc, - } - pathSrc := filepath.Join(outputDir, "doc.go") - if err := bavard.GenerateFromString(pathSrc, src, F, bavardOpts...); err != nil { + // run asmfmt on whole directory + cmd := exec.Command("asmfmt", "-w", outputDir) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { return err } } - if F.ASM { - // generate asm.go and asm_noadx.go - src := []string{ - element.Asm, - } - pathSrc := filepath.Join(outputDir, "asm_adx.go") - bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) - copy(bavardOptsCpy, bavardOpts) - bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!noadx")) - if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil { - return err - } - } - if F.ASM { - // generate asm.go and asm_noadx.go - src := []string{ - element.AsmNoAdx, - } - pathSrc := filepath.Join(outputDir, "asm_noadx.go") - bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) - copy(bavardOptsCpy, bavardOpts) - bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("noadx")) - if err := 
bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil { - return err - } - } + return nil +} - if F.ASMVector { - // generate asm.go and asm_noadx.go - src := []string{ - element.Avx, - } - pathSrc := filepath.Join(outputDir, "asm_avx.go") - bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) - copy(bavardOptsCpy, bavardOpts) - bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!noavx")) - if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil { - return err - } - } +type ASMWrapperData struct { + IncludePath string + Hash string +} - if F.ASMVector { - // generate asm.go and asm_noadx.go - src := []string{ - element.NoAvx, - } - pathSrc := filepath.Join(outputDir, "asm_noavx.go") - bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) - copy(bavardOptsCpy, bavardOpts) - bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("noavx")) - if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil { - return err - } +func hashAndInclude(asmDirBuildPath, asmDirIncludePath, fileName string, nbWords int) (data ASMWrapperData, err error) { + fileName = fmt.Sprintf(fileName, nbWords) + // we hash the file content and include the hash in comment of the generated file + // to force the Go compiler to recompile the file if the content has changed + fData, err := os.ReadFile(filepath.Join(asmDirBuildPath, fileName)) + if err != nil { + return ASMWrapperData{}, err } - - // run go fmt on whole directory - cmd := exec.Command("gofmt", "-s", "-w", outputDir) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return err + // hash the file using FNV + hasher := fnv.New64() + hasher.Write(fData) + hash64 := hasher.Sum64() + + hash := fmt.Sprintf("%d", hash64) + includePath := filepath.Join(asmDirIncludePath, fileName) + // on windows, we replace the "\" by "/" + if filepath.Separator == '\\' { + includePath = strings.ReplaceAll(includePath, "\\", "/") } - return nil + return ASMWrapperData{ + IncludePath: includePath, + Hash: hash, + }, nil + } func shorten(input string) string { @@ -384,3 +290,38 @@ func GenerateCommonASM(nbWords int, asmDir string, hasVector bool, hasArm bool) return nil } + +type option func(*generateConfig) +type generateConfig struct { + buildTag string + skip bool + tmplData any +} + +func WithBuildTag(buildTag string) option { + return func(opt *generateConfig) { + opt.buildTag = buildTag + } +} + +func Only(condition bool) option { + return func(opt *generateConfig) { + opt.skip = !condition + } +} + +func WithData(data any) option { + return func(opt *generateConfig) { + opt.tmplData = data + } +} + +// default options +func generateOptions(opts ...option) generateConfig { + // apply options + opt := generateConfig{} + for _, option := range opts { + option(&opt) + } + return opt +} diff --git a/field/generator/internal/templates/element/arith.go b/field/generator/internal/templates/element/arith.go index 7bf8086487..06a7805588 100644 --- a/field/generator/internal/templates/element/arith.go +++ b/field/generator/internal/templates/element/arith.go @@ -45,21 +45,5 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { return } -{{- if $.UsingP20Inverse}} - func max(a int, b int) int { - if a > b { - return a - } - return b - } - - func min(a int, b int) int { - if a < b { - return a - } - return b - } - -{{- end}} ` diff --git a/field/generator/internal/templates/element/ops_asm.go 
b/field/generator/internal/templates/element/ops_asm.go index 8fce3add34..91c0034c57 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -2,9 +2,6 @@ package element // OpsAMD64 is included with AMD64 builds (regardless of architecture or if F.ASM is set) const OpsAMD64 = ` - -{{if .ASM}} - //go:noescape func MulBy3(x *{{.ElementName}}) @@ -29,150 +26,6 @@ func reduce(res *{{.ElementName}}) //go:noescape func Butterfly(a, b *{{.ElementName}}) -{{- if .ASMVector}} -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Add: vectors don't have the same length") - } - n := uint64(len(a)) - addVec(&(*vector)[0], &a[0], &b[0], n) -} - -//go:noescape -func addVec(res, a, b *{{.ElementName}}, n uint64) - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Sub: vectors don't have the same length") - } - subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) -} - -//go:noescape -func subVec(res, a, b *{{.ElementName}}, n uint64) - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { - if len(a) != len(*vector) { - panic("vector.ScalarMul: vectors don't have the same length") - } - const maxN = (1 << 32) - 1 - if !supportAvx512 || uint64(len(a)) >= maxN { - // call scalarMulVecGeneric - scalarMulVecGeneric(*vector, a, b) - return - } - n := uint64(len(a)) - if n == 0 { - return - } - // the code for scalarMul is identical to mulVec; and it expects at least - // 2 elements in the vector to fill the Z registers - var bb [2]{{.ElementName}} - bb[0] = *b - bb[1] = *b - const blockSize = 16 - scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) - if n % blockSize != 0 { - // call scalarMulVecGeneric on the rest - start := n - n % blockSize - scalarMulVecGeneric((*vector)[start:], a[start:], b) - } -} - -//go:noescape -func scalarMulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64) - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res {{.ElementName}}) { - n := uint64(len(*vector)) - if n == 0 { - return - } - const minN = 16*7 // AVX512 slower than generic for small n - const maxN = (1 << 32) - 1 - if !supportAvx512 || n <= minN || n >= maxN { - // call sumVecGeneric - sumVecGeneric(&res, *vector) - return - } - sumVec(&res, &(*vector)[0], uint64(len(*vector))) - return -} - -//go:noescape -func sumVec(res *{{.ElementName}}, a *{{.ElementName}}, n uint64) - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { - n := uint64(len(*vector)) - if n == 0 { - return - } - if n != uint64(len(other)) { - panic("vector.InnerProduct: vectors don't have the same length") - } - const maxN = (1 << 32) - 1 - if !supportAvx512 || n >= maxN { - // call innerProductVecGeneric - // note; we could split the vector into smaller chunks and call innerProductVec - innerProductVecGeneric(&res, *vector, other) - return - } - innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) - - return -} - -//go:noescape -func innerProdVec(res *uint64, a,b *{{.ElementName}}, n uint64) - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Mul: vectors don't have the same length") - } - n := uint64(len(a)) - if n == 0 { - return - } - const maxN = (1 << 32) - 1 - if !supportAvx512 || n >= maxN { - // call mulVecGeneric - mulVecGeneric(*vector, a, b) - return - } - - const blockSize = 16 - mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg) - if n % blockSize != 0 { - // call mulVecGeneric on the rest - start := n - n % blockSize - mulVecGeneric((*vector)[start:], a[start:], b[start:]) - } - -} - -// Patterns use for transposing the vectors in mulVec -var ( - pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11} - pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7} - pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11} - pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7} -) - -//go:noescape -func mulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64) - -{{- end}} - // Mul z = x * y (mod q) // // x and y must be less than q @@ -191,25 +44,78 @@ func (z *{{.ElementName}}) Square(x *{{.ElementName}}) *{{.ElementName}} { return z } -{{end}} - - - ` const OpsARM64 = ` -{{if .ASMArm}} - +// Butterfly sets +// a = a + b (mod q) +// b = a - b (mod q) //go:noescape func Butterfly(a, b *{{.ElementName}}) //go:noescape func mul(res,x,y *{{.ElementName}}) +// Mul z = x * y (mod q) +// +// x and y must be less than q func (z *{{.ElementName}}) Mul(x, y *{{.ElementName}}) *{{.ElementName}} { mul(z,x,y) return z } -{{end}} +// Square z = x * x (mod q) +// +// x must be less than q +func (z *{{.ElementName}}) Square(x *{{.ElementName}}) *{{.ElementName}} { + // see Mul for doc. 
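+	// squaring reuses the Montgomery multiplication kernel (mul(z, x, x));
+	// a dedicated squaring routine that skips the duplicated cross terms is
+	// a possible optimization, not implemented here.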
+ mul(z, x, x) + return z +} + + +{{ $mulConsts := list 3 5 13 }} +{{- range $i := $mulConsts }} + +// MulBy{{$i}} x *= {{$i}} (mod q) +func MulBy{{$i}}(x *{{$.ElementName}}) { + {{- if eq 1 $.NbWords}} + var y {{$.ElementName}} + y.SetUint64({{$i}}) + x.Mul(x, &y) + {{- else}} + {{- if eq $i 3}} + _x := *x + x.Double(x).Add(x, &_x) + {{- else if eq $i 5}} + _x := *x + x.Double(x).Double(x).Add(x, &_x) + {{- else if eq $i 13}} + var y = {{$.ElementName}}{ + {{- range $i := $.Thirteen}} + {{$i}},{{end}} + } + x.Mul(x, &y) + {{- else }} + NOT IMPLEMENTED + {{- end}} + {{- end}} +} + +{{- end}} + +func fromMont(z *{{.ElementName}} ) { + _fromMontGeneric(z) +} + +func reduce(z *{{.ElementName}}) { + _reduceGeneric(z) +} +` + +const IncludeASM = ` + +// We include the hash to force the Go compiler to recompile: {{.Hash}} +#include "{{.IncludePath}}" + ` diff --git a/field/generator/internal/templates/element/ops_purego.go b/field/generator/internal/templates/element/ops_purego.go index de8fdc263e..498b7e1ae6 100644 --- a/field/generator/internal/templates/element/ops_purego.go +++ b/field/generator/internal/templates/element/ops_purego.go @@ -34,17 +34,6 @@ func MulBy{{$i}}(x *{{$.ElementName}}) { {{- end}} -{{- if not .ASMArm}} -// TODO @gbotrel fixme. -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *{{.ElementName}}) { - _butterflyGeneric(a, b) -} -{{- end}} - - func fromMont(z *{{.ElementName}} ) { _fromMontGeneric(z) } @@ -53,48 +42,6 @@ func reduce(z *{{.ElementName}}) { _reduceGeneric(z) } -{{- if .ASMVector}} -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res {{.ElementName}}) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. 
-func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - -{{- end}} - - -{{- if not .ASMArm}} // Mul z = x * y (mod q) {{- if $.NoCarry}} // @@ -114,7 +61,6 @@ func (z *{{.ElementName}}) Mul(x, y *{{.ElementName}}) *{{.ElementName}} { {{- end }} return z } -{{- end}} // Square z = x * x (mod q) {{- if $.NoCarry}} @@ -136,4 +82,11 @@ func (z *{{.ElementName}}) Square(x *{{.ElementName}}) *{{.ElementName}} { return z } +// Butterfly sets +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *{{.ElementName}}) { + _butterflyGeneric(a, b) +} + ` diff --git a/field/generator/internal/templates/element/vector.go b/field/generator/internal/templates/element/vector.go index 6407a024d5..447b429556 100644 --- a/field/generator/internal/templates/element/vector.go +++ b/field/generator/internal/templates/element/vector.go @@ -191,49 +191,6 @@ func (vector Vector) Swap(i, j int) { } -{{/* For 4 elements, we have a special assembly path and copy this in ops_pure.go */}} -{{- if not .ASMVector}} -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res {{.ElementName}}) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - -{{- end}} - - - func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go new file mode 100644 index 0000000000..e4de4f1b27 --- /dev/null +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -0,0 +1,147 @@ +package element + +const VectorOpsAmd64 = ` +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + addVec(&(*vector)[0], &a[0], &b[0], n) +} + +//go:noescape +func addVec(res, a, b *{{.ElementName}}, n uint64) + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
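+// Note that Add and Sub always dispatch to the assembly kernel; unlike
+// Mul, ScalarMul and Sum below, they do not gate on supportAvx512.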
+func (vector *Vector) Sub(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + subVec(&(*vector)[0], &a[0], &b[0], uint64(len(a))) +} + +//go:noescape +func subVec(res, a, b *{{.ElementName}}, n uint64) + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || uint64(len(a)) >= maxN { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + n := uint64(len(a)) + if n == 0 { + return + } + // the code for scalarMul is identical to mulVec; and it expects at least + // 2 elements in the vector to fill the Z registers + var bb [2]{{.ElementName}} + bb[0] = *b + bb[1] = *b + const blockSize = 16 + scalarMulVec(&(*vector)[0], &a[0], &bb[0], n/blockSize, qInvNeg) + if n % blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n % blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } +} + +//go:noescape +func scalarMulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res {{.ElementName}}) { + n := uint64(len(*vector)) + if n == 0 { + return + } + const minN = 16*7 // AVX512 slower than generic for small n + const maxN = (1 << 32) - 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *{{.ElementName}}, a *{{.ElementName}}, n uint64) + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + const maxN = (1 << 32) - 1 + if !supportAvx512 || n >= maxN { + // call innerProductVecGeneric + // note; we could split the vector into smaller chunks and call innerProductVec + innerProductVecGeneric(&res, *vector, other) + return + } + innerProdVec(&res[0], &(*vector)[0], &other[0], uint64(len(*vector))) + + return +} + +//go:noescape +func innerProdVec(res *uint64, a,b *{{.ElementName}}, n uint64) + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
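+// On the AVX-512 path, mulVec processes the input in blocks of 16 elements;
+// any tail shorter than a block falls back to mulVecGeneric.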
+func (vector *Vector) Mul(a, b Vector) {
+	if len(a) != len(b) || len(a) != len(*vector) {
+		panic("vector.Mul: vectors don't have the same length")
+	}
+	n := uint64(len(a))
+	if n == 0 {
+		return
+	}
+	const maxN = (1 << 32) - 1
+	if !supportAvx512 || n >= maxN {
+		// call mulVecGeneric
+		mulVecGeneric(*vector, a, b)
+		return
+	}
+
+	const blockSize = 16
+	mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize, qInvNeg)
+	if n % blockSize != 0 {
+		// call mulVecGeneric on the rest
+		start := n - n % blockSize
+		mulVecGeneric((*vector)[start:], a[start:], b[start:])
+	}
+
+}
+
+// Patterns used for transposing the vectors in mulVec
+var (
+	pattern1 = [8]uint64{0, 8, 1, 9, 2, 10, 3, 11}
+	pattern2 = [8]uint64{12, 4, 13, 5, 14, 6, 15, 7}
+	pattern3 = [8]uint64{0, 1, 8, 9, 2, 3, 10, 11}
+	pattern4 = [8]uint64{12, 13, 4, 5, 14, 15, 6, 7}
+)
+
+//go:noescape
+func mulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64)
+
+`
+
+const VectorOpsArm64 = VectorOpsPureGo
diff --git a/field/generator/internal/templates/element/vector_ops_purego.go b/field/generator/internal/templates/element/vector_ops_purego.go
new file mode 100644
index 0000000000..071b710587
--- /dev/null
+++ b/field/generator/internal/templates/element/vector_ops_purego.go
@@ -0,0 +1,40 @@
+package element
+
+const VectorOpsPureGo = `
+// Add adds two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Add(a, b Vector) {
+	addVecGeneric(*vector, a, b)
+}
+
+// Sub subtracts two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Sub(a, b Vector) {
+	subVecGeneric(*vector, a, b)
+}
+
+// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) {
+	scalarMulVecGeneric(*vector, a, b)
+}
+
+// Sum computes the sum of all elements in the vector.
+func (vector *Vector) Sum() (res {{.ElementName}}) {
+	sumVecGeneric(&res, *vector)
+	return
+}
+
+// InnerProduct computes the inner product of two vectors.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) {
+	innerProductVecGeneric(&res, *vector, other)
+	return
+}
+
+// Mul multiplies two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Mul(a, b Vector) {
+	mulVecGeneric(*vector, a, b)
+}
+`
diff --git a/field/goff/cmd/root.go b/field/goff/cmd/root.go
index 58d36dbdee..07a21bbc2f 100644
--- a/field/goff/cmd/root.go
+++ b/field/goff/cmd/root.go
@@ -73,7 +73,7 @@ func cmdGenerate(cmd *cobra.Command, args []string) {
 	}
 
 	asmDir := filepath.Join(fOutputDir, "asm")
-	if err := generator.GenerateCommonASM(F.NbWords, asmDir, F.ASMVector); err != nil {
+	if err := generator.GenerateCommonASM(F.NbWords, asmDir, F.GenerateVectorOpsAMD64, F.GenerateOpsARM64); err != nil {
 		fmt.Printf("\n%s\n", err.Error())
 		os.Exit(-1)
 	}
diff --git a/field/goldilocks/element_ops_purego.go b/field/goldilocks/element_purego.go
similarity index 99%
rename from field/goldilocks/element_ops_purego.go
rename to field/goldilocks/element_purego.go
index 1985eeb161..f1090ab75f 100644
--- a/field/goldilocks/element_ops_purego.go
+++ b/field/goldilocks/element_purego.go
@@ -39,15 +39,6 @@ func MulBy13(x *Element) {
 	x.Mul(x, &y)
 }
 
-// TODO @gbotrel fixme.
-// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - func fromMont(z *Element) { _fromMontGeneric(z) } @@ -122,3 +113,11 @@ func (z *Element) Square(x *Element) *Element { return z } + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/field/goldilocks/vector.go b/field/goldilocks/vector.go index 7411cb7bf7..47b0664aa3 100644 --- a/field/goldilocks/vector.go +++ b/field/goldilocks/vector.go @@ -196,43 +196,6 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/field/goldilocks/vector_purego.go b/field/goldilocks/vector_purego.go new file mode 100644 index 0000000000..857d3e9f4c --- /dev/null +++ b/field/goldilocks/vector_purego.go @@ -0,0 +1,54 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package goldilocks + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
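+// note: there is no assembly fast path for goldilocks vectors in this
+// patch; this method, like the others in this file, delegates to the
+// generic implementation.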
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/go.mod b/go.mod index e5a21a2e84..4297d50546 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.23-0.20241019150039-28659c2eb91c + github.com/consensys/bavard v0.1.23-0.20241021201139-ab3fee069cde github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index 026f2ec62e..40ca76c748 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,8 @@ github.com/consensys/bavard v0.1.23-0.20241015221109-a56d5bf777eb h1:yPPmCz5FvvK github.com/consensys/bavard v0.1.23-0.20241015221109-a56d5bf777eb/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/bavard v0.1.23-0.20241019150039-28659c2eb91c h1:sK5i7h6ZVAj2eK7Vt5CzSnenlsxp828qvga+X5TjSVM= github.com/consensys/bavard v0.1.23-0.20241019150039-28659c2eb91c/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241021201139-ab3fee069cde h1:KXywceL5kuPe9PAQHHBvt4Kki7/XqsW7ABJI9dn4zik= +github.com/consensys/bavard v0.1.23-0.20241021201139-ab3fee069cde/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= diff --git a/internal/generator/main.go b/internal/generator/main.go index 2f84a433e7..a5232b4b7f 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -68,16 +68,16 @@ func main() { mCommon[conf.Fr.NbWords] = true mCommon[conf.Fp.NbWords] = true - if conf.Fr.ASMVector { + if conf.Fr.GenerateVectorOpsAMD64 { mVec[conf.Fr.NbWords] = true } - if conf.Fp.ASMVector { + if conf.Fp.GenerateVectorOpsAMD64 { mVec[conf.Fp.NbWords] = true } - if conf.Fr.ASMArm { + if conf.Fr.GenerateOpsARM64 { mArm[conf.Fr.NbWords] = true } - if conf.Fp.ASMArm { + if conf.Fp.GenerateOpsARM64 { mArm[conf.Fp.NbWords] = true } From 99283df9b3d844ec397b185cf559efdc9d21634f Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 21 Oct 2024 15:34:09 -0500 Subject: [PATCH 17/74] style: code cleaning --- field/generator/generator.go | 83 ++++++++++++++++--------------- field/generator/generator_test.go | 8 +-- field/goff/cmd/root.go | 14 ++++-- internal/generator/main.go | 49 ++++++------------ 4 files changed, 72 insertions(+), 82 deletions(-) diff --git a/field/generator/generator.go b/field/generator/generator.go index d21c6632e1..2b6dd2d9f0 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -236,56 +236,57 @@ func shorten(input string) string { return input } -func GenerateCommonASM(nbWords int, 
asmDir string, hasVector bool, hasArm bool) error {
+func GenerateARM64(nbWords int, asmDir string, hasVector bool) error {
 	os.MkdirAll(asmDir, 0755)
-	{
-		pathSrc := filepath.Join(asmDir, fmt.Sprintf(amd64.ElementASMFileName, nbWords))
+	pathSrc := filepath.Join(asmDir, fmt.Sprintf(arm64.ElementASMFileName, nbWords))
 
-		fmt.Println("generating", pathSrc)
-		f, err := os.Create(pathSrc)
-		if err != nil {
-			return err
-		}
+	fmt.Println("generating", pathSrc)
+	f, err := os.Create(pathSrc)
+	if err != nil {
+		return err
+	}
 
-		if err := amd64.GenerateCommonASM(f, nbWords, hasVector); err != nil {
-			_ = f.Close()
-			return err
-		}
+	if err := arm64.GenerateCommonASM(f, nbWords, hasVector); err != nil {
 		_ = f.Close()
-
-		// run asmfmt
-		// run go fmt on whole directory
-		cmd := exec.Command("asmfmt", "-w", pathSrc)
-		cmd.Stdout = os.Stdout
-		cmd.Stderr = os.Stderr
-		if err := cmd.Run(); err != nil {
-			return err
-		}
+		return err
+	}
+	_ = f.Close()
+
+	// run asmfmt on the generated file
+	cmd := exec.Command("asmfmt", "-w", pathSrc)
+	cmd.Stdout = os.Stdout
+	cmd.Stderr = os.Stderr
+	if err := cmd.Run(); err != nil {
+		return err
 	}
 
-	if hasArm {
-		pathSrc := filepath.Join(asmDir, fmt.Sprintf(arm64.ElementASMFileName, nbWords))
+	return nil
+}
 
-		fmt.Println("generating", pathSrc)
-		f, err := os.Create(pathSrc)
-		if err != nil {
-			return err
-		}
+func GenerateAMD64(nbWords int, asmDir string, hasVector bool) error {
+	os.MkdirAll(asmDir, 0755)
+	pathSrc := filepath.Join(asmDir, fmt.Sprintf(amd64.ElementASMFileName, nbWords))
 
-		if err := arm64.GenerateCommonASM(f, nbWords, hasVector); err != nil {
-			_ = f.Close()
-			return err
-		}
-		_ = f.Close()
+	fmt.Println("generating", pathSrc)
+	f, err := os.Create(pathSrc)
+	if err != nil {
+		return err
+	}
 
-		// run asmfmt
-		// run go fmt on whole directory
-		cmd := exec.Command("asmfmt", "-w", pathSrc)
-		cmd.Stdout = os.Stdout
-		cmd.Stderr = os.Stderr
-		if err := cmd.Run(); err != nil {
-			return err
-		}
+	if err := amd64.GenerateCommonASM(f, nbWords, hasVector); err != nil {
+		_ = f.Close()
+		return err
+	}
+	_ = f.Close()
+
+	// run asmfmt on the generated file
+	cmd := exec.Command("asmfmt", "-w", pathSrc)
+	cmd.Stdout = os.Stdout
+	cmd.Stderr = os.Stderr
+	if err := cmd.Run(); err != nil {
+		return err
 	}
 
 	return nil
diff --git a/field/generator/generator_test.go b/field/generator/generator_test.go
index 3da09033a3..e490baa03d 100644
--- a/field/generator/generator_test.go
+++ b/field/generator/generator_test.go
@@ -79,10 +79,10 @@ func TestIntegration(t *testing.T) {
 	moduli["e_nocarry_edge_0127"] = "170141183460469231731687303715884105727"
 	moduli["e_nocarry_edge_1279"] = "10407932194664399081925240327364085538615262247266704805319112350403608059673360298012239441732324184842421613954281007791383566248323464908139906605677320762924129509389220345773183349661583550472959420547689811211693677147548478866962501384438260291732348885311160828538416585028255604666224831890918801847068222203140521026698435488732958028878050869736186900714720710555703168729087"
 
-	assert.NoError(GenerateCommonASM(2, asmDir, false, false))
-	assert.NoError(GenerateCommonASM(3, asmDir, false, false))
-	assert.NoError(GenerateCommonASM(7, asmDir, false, false))
-	assert.NoError(GenerateCommonASM(8, asmDir, false, false))
+	assert.NoError(GenerateAMD64(2, asmDir, false))
+	assert.NoError(GenerateAMD64(3, asmDir, false))
+	assert.NoError(GenerateAMD64(7, asmDir, false))
+	assert.NoError(GenerateAMD64(8, asmDir, false))
 
 	for elementName, modulus := range moduli {
var fIntegration *field.FieldConfig diff --git a/field/goff/cmd/root.go b/field/goff/cmd/root.go index 07a21bbc2f..c2ae86809e 100644 --- a/field/goff/cmd/root.go +++ b/field/goff/cmd/root.go @@ -73,9 +73,17 @@ func cmdGenerate(cmd *cobra.Command, args []string) { } asmDir := filepath.Join(fOutputDir, "asm") - if err := generator.GenerateCommonASM(F.NbWords, asmDir, F.GenerateVectorOpsAMD64, F.GenerateOpsARM64); err != nil { - fmt.Printf("\n%s\n", err.Error()) - os.Exit(-1) + if F.GenerateOpsAMD64 { + if err := generator.GenerateAMD64(F.NbWords, asmDir, F.GenerateVectorOpsAMD64); err != nil { + fmt.Printf("\n%s\n", err.Error()) + os.Exit(-1) + } + } + if F.GenerateOpsARM64 { + if err := generator.GenerateARM64(F.NbWords, asmDir, F.GenerateVectorOpsARM64); err != nil { + fmt.Printf("\n%s\n", err.Error()) + os.Exit(-1) + } } if err := generator.GenerateFF(F, fOutputDir, asmDir, "asm"); err != nil { diff --git a/internal/generator/main.go b/internal/generator/main.go index a5232b4b7f..45fbb4fe20 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -52,41 +52,14 @@ func main() { asmDirIncludePath := filepath.Join(baseDir, "..", "field", "asm") // generate common assembly files depending on field number of words - mCommon := make(map[int]bool) - mVec := make(map[int]bool) - mArm := make(map[int]bool) - - for i, conf := range config.Curves { - var err error - // generate base field - conf.Fp, err = field.NewFieldConfig("fp", "Element", conf.FpModulus, true) - assertNoError(err) - - conf.Fr, err = field.NewFieldConfig("fr", "Element", conf.FrModulus, !conf.Equal(config.STARK_CURVE)) - assertNoError(err) - - mCommon[conf.Fr.NbWords] = true - mCommon[conf.Fp.NbWords] = true - - if conf.Fr.GenerateVectorOpsAMD64 { - mVec[conf.Fr.NbWords] = true - } - if conf.Fp.GenerateVectorOpsAMD64 { - mVec[conf.Fp.NbWords] = true - } - if conf.Fr.GenerateOpsARM64 { - mArm[conf.Fr.NbWords] = true - } - if conf.Fp.GenerateOpsARM64 { - mArm[conf.Fp.NbWords] = true - } - - config.Curves[i] = conf - } + assertNoError(generator.GenerateAMD64(4, asmDirBuildPath, true)) + assertNoError(generator.GenerateAMD64(5, asmDirBuildPath, false)) + assertNoError(generator.GenerateAMD64(6, asmDirBuildPath, false)) + assertNoError(generator.GenerateAMD64(10, asmDirBuildPath, false)) + assertNoError(generator.GenerateAMD64(12, asmDirBuildPath, false)) - for nbWords := range mCommon { - assertNoError(generator.GenerateCommonASM(nbWords, asmDirBuildPath, mVec[nbWords], mArm[nbWords])) - } + assertNoError(generator.GenerateARM64(4, asmDirBuildPath, false)) + assertNoError(generator.GenerateARM64(6, asmDirBuildPath, false)) var wg sync.WaitGroup for _, conf := range config.Curves { @@ -95,6 +68,14 @@ func main() { go func(conf config.Curve) { defer wg.Done() + var err error + + conf.Fp, err = field.NewFieldConfig("fp", "Element", conf.FpModulus, true) + assertNoError(err) + + conf.Fr, err = field.NewFieldConfig("fr", "Element", conf.FrModulus, !conf.Equal(config.STARK_CURVE)) + assertNoError(err) + curveDir := filepath.Join(baseDir, "ecc", conf.Name) conf.FpUnusedBits = 64 - (conf.Fp.NbBits % 64) From 08a7afae6216de9176e79920525a27b57a0cbb16 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 21 Oct 2024 15:55:03 -0500 Subject: [PATCH 18/74] feat: add reduce arm64 for test purposes --- ecc/bls12-377/fp/element_arm64.go | 5 ++-- ecc/bls12-377/fp/element_arm64.s | 2 +- ecc/bls12-377/fr/element_arm64.go | 5 ++-- ecc/bls12-377/fr/element_arm64.s | 2 +- ecc/bls12-381/fp/element_arm64.go | 5 ++-- 
ecc/bls12-381/fp/element_arm64.s | 2 +- ecc/bls12-381/fr/element_arm64.go | 5 ++-- ecc/bls12-381/fr/element_arm64.s | 2 +- ecc/bls24-315/fr/element_arm64.go | 5 ++-- ecc/bls24-315/fr/element_arm64.s | 2 +- ecc/bls24-317/fr/element_arm64.go | 5 ++-- ecc/bls24-317/fr/element_arm64.s | 2 +- ecc/bn254/fp/element_arm64.go | 5 ++-- ecc/bn254/fp/element_arm64.s | 2 +- ecc/bn254/fr/element_arm64.go | 5 ++-- ecc/bn254/fr/element_arm64.s | 2 +- ecc/bn254/fr/mimc/test_vectors/vectors.json | 16 +++++----- ecc/bw6-761/fr/element_arm64.go | 5 ++-- ecc/bw6-761/fr/element_arm64.s | 2 +- ecc/stark-curve/fp/element_arm64.go | 5 ++-- ecc/stark-curve/fp/element_arm64.s | 2 +- ecc/stark-curve/fr/element_arm64.go | 5 ++-- ecc/stark-curve/fr/element_arm64.s | 2 +- field/asm/element_4w_arm64.s | 23 ++++++++++++++ field/asm/element_6w_arm64.s | 30 +++++++++++++++++++ field/generator/asm/arm64/build.go | 1 + field/generator/asm/arm64/element_ops.go | 21 +++++++++++++ .../internal/templates/element/ops_asm.go | 5 ++-- 28 files changed, 118 insertions(+), 55 deletions(-) diff --git a/ecc/bls12-377/fp/element_arm64.go b/ecc/bls12-377/fp/element_arm64.go index 9c9a211f27..9793f7eb28 100644 --- a/ecc/bls12-377/fp/element_arm64.go +++ b/ecc/bls12-377/fp/element_arm64.go @@ -75,6 +75,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/bls12-377/fp/element_arm64.s b/ecc/bls12-377/fp/element_arm64.s index 62de3f0be7..4c01eca83e 100644 --- a/ecc/bls12-377/fp/element_arm64.s +++ b/ecc/bls12-377/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 4799084555005768587 +// We include the hash to force the Go compiler to recompile: 17561434332277668166 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-377/fr/element_arm64.go b/ecc/bls12-377/fr/element_arm64.go index 07888f58ba..fce7cbcd94 100644 --- a/ecc/bls12-377/fr/element_arm64.go +++ b/ecc/bls12-377/fr/element_arm64.go @@ -73,6 +73,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/bls12-377/fr/element_arm64.s b/ecc/bls12-377/fr/element_arm64.s index 3cd8aaa667..75bf9d9d1f 100644 --- a/ecc/bls12-377/fr/element_arm64.s +++ b/ecc/bls12-377/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 18027907654287790676 +// We include the hash to force the Go compiler to recompile: 17105046060840004046 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fp/element_arm64.go b/ecc/bls12-381/fp/element_arm64.go index aed4105795..044b606293 100644 --- a/ecc/bls12-381/fp/element_arm64.go +++ b/ecc/bls12-381/fp/element_arm64.go @@ -75,6 +75,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/bls12-381/fp/element_arm64.s b/ecc/bls12-381/fp/element_arm64.s index 62de3f0be7..4c01eca83e 100644 --- a/ecc/bls12-381/fp/element_arm64.s +++ b/ecc/bls12-381/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 4799084555005768587 +// We include the hash to force the Go compiler to recompile: 17561434332277668166 #include 
"../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-381/fr/element_arm64.go b/ecc/bls12-381/fr/element_arm64.go index e3f8d21ce7..e82da3dad7 100644 --- a/ecc/bls12-381/fr/element_arm64.go +++ b/ecc/bls12-381/fr/element_arm64.go @@ -73,6 +73,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/bls12-381/fr/element_arm64.s b/ecc/bls12-381/fr/element_arm64.s index 3cd8aaa667..75bf9d9d1f 100644 --- a/ecc/bls12-381/fr/element_arm64.s +++ b/ecc/bls12-381/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 18027907654287790676 +// We include the hash to force the Go compiler to recompile: 17105046060840004046 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element_arm64.go b/ecc/bls24-315/fr/element_arm64.go index 7b52729f61..cdf726a8eb 100644 --- a/ecc/bls24-315/fr/element_arm64.go +++ b/ecc/bls24-315/fr/element_arm64.go @@ -73,6 +73,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/bls24-315/fr/element_arm64.s b/ecc/bls24-315/fr/element_arm64.s index 3cd8aaa667..75bf9d9d1f 100644 --- a/ecc/bls24-315/fr/element_arm64.s +++ b/ecc/bls24-315/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 18027907654287790676 +// We include the hash to force the Go compiler to recompile: 17105046060840004046 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element_arm64.go b/ecc/bls24-317/fr/element_arm64.go index ab8a390bdf..7a483d1551 100644 --- a/ecc/bls24-317/fr/element_arm64.go +++ b/ecc/bls24-317/fr/element_arm64.go @@ -73,6 +73,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/bls24-317/fr/element_arm64.s b/ecc/bls24-317/fr/element_arm64.s index 3cd8aaa667..75bf9d9d1f 100644 --- a/ecc/bls24-317/fr/element_arm64.s +++ b/ecc/bls24-317/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 18027907654287790676 +// We include the hash to force the Go compiler to recompile: 17105046060840004046 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element_arm64.go b/ecc/bn254/fp/element_arm64.go index bd05e23e63..dfb583df86 100644 --- a/ecc/bn254/fp/element_arm64.go +++ b/ecc/bn254/fp/element_arm64.go @@ -73,6 +73,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/bn254/fp/element_arm64.s b/ecc/bn254/fp/element_arm64.s index 3cd8aaa667..75bf9d9d1f 100644 --- a/ecc/bn254/fp/element_arm64.s +++ b/ecc/bn254/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 18027907654287790676 +// We include the hash to force the Go compiler to recompile: 17105046060840004046 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element_arm64.go b/ecc/bn254/fr/element_arm64.go index a0ccad8e0e..8b28f5123b 100644 --- a/ecc/bn254/fr/element_arm64.go +++ 
b/ecc/bn254/fr/element_arm64.go @@ -73,6 +73,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/bn254/fr/element_arm64.s b/ecc/bn254/fr/element_arm64.s index 3cd8aaa667..75bf9d9d1f 100644 --- a/ecc/bn254/fr/element_arm64.s +++ b/ecc/bn254/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 18027907654287790676 +// We include the hash to force the Go compiler to recompile: 17105046060840004046 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/mimc/test_vectors/vectors.json b/ecc/bn254/fr/mimc/test_vectors/vectors.json index d0961fe637..cc1cb8eb53 100644 --- a/ecc/bn254/fr/mimc/test_vectors/vectors.json +++ b/ecc/bn254/fr/mimc/test_vectors/vectors.json @@ -3,13 +3,13 @@ "in": [ "0x105afe02a0f7648bee1669b05bf7ae69a37dbb6c86ebbee325dffe97ac1f8e64" ], - "out": "0x263b9e754e6c611d646e65b16c48f51ab7bc0abedfae9c6ea04e2814ed28daf4" + "out": "0x263b9e754e6c611d646e65b16c48f518b7bc0abedfae9c6ea04e2814ed28daf4" }, { "in": [ "0xbc35f0589078e34d9139a357175d0e74b843e3de2f56bb6a5f18032ff3f627" ], - "out": "0x103e2c8f50dec5248fd68f9778429252364ff239b123977a697356082f524f25" + "out": "0x24157ce0622879c8e344dbbc67835d3102d033cce465a2b494bb2eed60fb7553" }, { "in": [ @@ -31,7 +31,7 @@ "0x59323b0ab7043f559674eba263da812eae9e933b0c1bad55f8118d0caaa7479", "0x16b161c8de7184ccc6b1b6fcddb562789a68eeaec174376f1157dfb3db310787" ], - "out": "0x118e5255aabe7a3b6a5dde6ca28de461d36f802653885c665745fc4e6ca0f709" + "out": "0xe2b910ebf0c6c0e50d2d5e2703c45973805346cafb9f19200fd7645ff47048c" }, { "in": [ @@ -39,7 +39,7 @@ "0x18ff125903dc8352ca63c7a436f0425b4b7ddf7e487fb9ffd30f151993571b57", "0x2cbfaa412f4b612d611acaab79a9e1c06b7094d8754fdbc085db28f2e4dd09ab" ], - "out": "0x25fa55a9896d91d9617d9512e061d754336816f748bf07566591ec5cf4680dd" + "out": "0x1e2e2466b7fb3569539d639125ac5273fd0029eecaebeb0d3722133834429a62" }, { "in": [ @@ -48,7 +48,7 @@ "0x22e63a3eb565d13c42c7d520c7b6112534b1c666653452f743b80bcc2d878455", "0x96dff377f354f792685a7e740e3024409c24a379425ff63e3ce320b1e9bc471" ], - "out": "0x18ea2fd58b4f3274193f3e73bded685426f114a8e7bc373c1aee3e6f0125787b" + "out": "0x16a858453a534784b95fb52f40e203db21c51a6bb9c13599c0d7aaa802e190aa" }, { "in": [ @@ -57,7 +57,7 @@ "0x12f4c27e5a2e80dd67fb33928c4e6219a8bdc89b498ed32acb02d725cec90076", "0x1d6b52c237f0f74f0c50755627eed2610608488b54b0a3941a4623b1d435232a" ], - "out": "0x2f237dea4570779296e2866383740b8e9ccf59577f8ff729880dadb58ae34d47" + "out": "0x2759e8bb5393ea5eb110eb977de64c01f6cbf9a434ed7d0679c8a937f39e19c" }, { "in": [ @@ -67,7 +67,7 @@ "0x2d4232cb721888f71997377de5ca195a5ae03d3eb9c87d2c04ef3664759036da", "0x2f623ee75518430e291d42e7aaa75f5291a1bbfed125426d39270046a26be35a" ], - "out": "0x6246dee7e2d9560a074c50a06e6525e4a58395cea4a893c49d71e373f19b9d6" + "out": "0x2a1dd4b9af5fc4055302c33d7a24f0070c660d27984d3169757f83b2d8c5a52" }, { "in": [ @@ -77,6 +77,6 @@ "0x8e0ddb80366c4c6c7dcb9090f4862d64ef40677d324a76a82e06ca33ad29a09", "0x170e8c954ca7e6526b743e92f796488afe5083a9c549358f730659c3e1cdbafa" ], - "out": "0x1a2e7cffb5183898a8f4f6d4699bc272665ebffbb9d095576d2e21c45f012358" + "out": "0x1d54497fe996571b75b403593832da3d668a825b0978dc57afd38b886ef41e4c" } ] \ No newline at end of file diff --git a/ecc/bw6-761/fr/element_arm64.go b/ecc/bw6-761/fr/element_arm64.go index 52262cdba6..7b389c515b 100644 --- a/ecc/bw6-761/fr/element_arm64.go +++ 
b/ecc/bw6-761/fr/element_arm64.go @@ -75,6 +75,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/bw6-761/fr/element_arm64.s b/ecc/bw6-761/fr/element_arm64.s index 62de3f0be7..4c01eca83e 100644 --- a/ecc/bw6-761/fr/element_arm64.s +++ b/ecc/bw6-761/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 4799084555005768587 +// We include the hash to force the Go compiler to recompile: 17561434332277668166 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/stark-curve/fp/element_arm64.go b/ecc/stark-curve/fp/element_arm64.go index c77bc2ea84..776a544644 100644 --- a/ecc/stark-curve/fp/element_arm64.go +++ b/ecc/stark-curve/fp/element_arm64.go @@ -73,6 +73,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/stark-curve/fp/element_arm64.s b/ecc/stark-curve/fp/element_arm64.s index 3cd8aaa667..75bf9d9d1f 100644 --- a/ecc/stark-curve/fp/element_arm64.s +++ b/ecc/stark-curve/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 18027907654287790676 +// We include the hash to force the Go compiler to recompile: 17105046060840004046 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_arm64.go b/ecc/stark-curve/fr/element_arm64.go index 474236e22f..69e7dd8b03 100644 --- a/ecc/stark-curve/fr/element_arm64.go +++ b/ecc/stark-curve/fr/element_arm64.go @@ -73,6 +73,5 @@ func fromMont(z *Element) { _fromMontGeneric(z) } -func reduce(z *Element) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *Element) diff --git a/ecc/stark-curve/fr/element_arm64.s b/ecc/stark-curve/fr/element_arm64.s index 3cd8aaa667..75bf9d9d1f 100644 --- a/ecc/stark-curve/fr/element_arm64.s +++ b/ecc/stark-curve/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 18027907654287790676 +// We include the hash to force the Go compiler to recompile: 17105046060840004046 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index 071e2495c3..58eccf4142 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -138,3 +138,26 @@ TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 CSEL CS, R10, R14, R14 STP (R13, R14), 16(R0) RET + +// reduce(res *Element) +TEXT ·reduce(SB), NOFRAME|NOSPLIT, $0-8 + LDP ·qElement+0(SB), (R4, R5) + LDP ·qElement+16(SB), (R6, R7) + MOVD res+0(FP), R8 + LDP 0(R8), (R0, R1) + LDP 16(R8), (R2, R3) + + // q = t - q + SUBS R4, R0, R4 + SBCS R5, R1, R5 + SBCS R6, R2, R6 + SBCS R7, R3, R7 + + // if no borrow, return q, else return t + CSEL CS, R4, R0, R0 + CSEL CS, R5, R1, R1 + STP (R0, R1), 0(R8) + CSEL CS, R6, R2, R2 + CSEL CS, R7, R3, R3 + STP (R2, R3), 16(R8) + RET diff --git a/field/asm/element_6w_arm64.s b/field/asm/element_6w_arm64.s index bcba9ee6b1..7b4946b3f9 100644 --- a/field/asm/element_6w_arm64.s +++ b/field/asm/element_6w_arm64.s @@ -188,3 +188,33 @@ TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 CSEL CS, R14, R21, R21 STP (R20, R21), 32(R0) RET + +// reduce(res *Element) +TEXT ·reduce(SB), NOFRAME|NOSPLIT, $0-8 + LDP ·qElement+0(SB), (R6, R7) + LDP 
·qElement+16(SB), (R8, R9) + LDP ·qElement+32(SB), (R10, R11) + MOVD res+0(FP), R12 + LDP 0(R12), (R0, R1) + LDP 16(R12), (R2, R3) + LDP 32(R12), (R4, R5) + + // q = t - q + SUBS R6, R0, R6 + SBCS R7, R1, R7 + SBCS R8, R2, R8 + SBCS R9, R3, R9 + SBCS R10, R4, R10 + SBCS R11, R5, R11 + + // if no borrow, return q, else return t + CSEL CS, R6, R0, R0 + CSEL CS, R7, R1, R1 + STP (R0, R1), 0(R12) + CSEL CS, R8, R2, R2 + CSEL CS, R9, R3, R3 + STP (R2, R3), 16(R12) + CSEL CS, R10, R4, R4 + CSEL CS, R11, R5, R5 + STP (R4, R5), 32(R12) + RET diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index f3b23f31d6..cba7bde8bc 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -109,6 +109,7 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { f.generateButterfly() f.generateMul() + f.generateReduce() return nil } diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go index 8d427a1286..354cd68d99 100644 --- a/field/generator/asm/arm64/element_ops.go +++ b/field/generator/asm/arm64/element_ops.go @@ -64,6 +64,27 @@ func (f *FFArm64) generateButterfly() { f.RET() } +func (f *FFArm64) generateReduce() { + f.Comment("reduce(res *Element)") + registers := f.FnHeader("reduce", 0, 8) + defer f.AssertCleanStack(0, 0) + + // registers + t := registers.PopN(f.NbWords) + q := registers.PopN(f.NbWords) + rPtr := registers.Pop() + + for i := 0; i < f.NbWords; i += 2 { + f.LDP(f.qAt(i), q[i], q[i+1]) + } + + f.MOVD("res+0(FP)", rPtr) + f.load(rPtr, t) + f.reduceAndStore(t, q, rPtr) + + f.RET() +} + func (f *FFArm64) generateMul() { f.Comment("mul(res, x, y *Element)") f.Comment("Algorithm 2 of Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS") diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 91c0034c57..271d7d1165 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -108,9 +108,8 @@ func fromMont(z *{{.ElementName}} ) { _fromMontGeneric(z) } -func reduce(z *{{.ElementName}}) { - _reduceGeneric(z) -} +//go:noescape +func reduce(res *{{.ElementName}}) ` const IncludeASM = ` From 771ce5428f5965cfacf9df6d45341a77dfd6e4b1 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 22 Oct 2024 08:35:24 -0500 Subject: [PATCH 19/74] feat: restore vectors.json mimc --- ecc/bn254/fr/mimc/test_vectors/vectors.json | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ecc/bn254/fr/mimc/test_vectors/vectors.json b/ecc/bn254/fr/mimc/test_vectors/vectors.json index cc1cb8eb53..d0961fe637 100644 --- a/ecc/bn254/fr/mimc/test_vectors/vectors.json +++ b/ecc/bn254/fr/mimc/test_vectors/vectors.json @@ -3,13 +3,13 @@ "in": [ "0x105afe02a0f7648bee1669b05bf7ae69a37dbb6c86ebbee325dffe97ac1f8e64" ], - "out": "0x263b9e754e6c611d646e65b16c48f518b7bc0abedfae9c6ea04e2814ed28daf4" + "out": "0x263b9e754e6c611d646e65b16c48f51ab7bc0abedfae9c6ea04e2814ed28daf4" }, { "in": [ "0xbc35f0589078e34d9139a357175d0e74b843e3de2f56bb6a5f18032ff3f627" ], - "out": "0x24157ce0622879c8e344dbbc67835d3102d033cce465a2b494bb2eed60fb7553" + "out": "0x103e2c8f50dec5248fd68f9778429252364ff239b123977a697356082f524f25" }, { "in": [ @@ -31,7 +31,7 @@ "0x59323b0ab7043f559674eba263da812eae9e933b0c1bad55f8118d0caaa7479", "0x16b161c8de7184ccc6b1b6fcddb562789a68eeaec174376f1157dfb3db310787" ], - "out": 
"0xe2b910ebf0c6c0e50d2d5e2703c45973805346cafb9f19200fd7645ff47048c" + "out": "0x118e5255aabe7a3b6a5dde6ca28de461d36f802653885c665745fc4e6ca0f709" }, { "in": [ @@ -39,7 +39,7 @@ "0x18ff125903dc8352ca63c7a436f0425b4b7ddf7e487fb9ffd30f151993571b57", "0x2cbfaa412f4b612d611acaab79a9e1c06b7094d8754fdbc085db28f2e4dd09ab" ], - "out": "0x1e2e2466b7fb3569539d639125ac5273fd0029eecaebeb0d3722133834429a62" + "out": "0x25fa55a9896d91d9617d9512e061d754336816f748bf07566591ec5cf4680dd" }, { "in": [ @@ -48,7 +48,7 @@ "0x22e63a3eb565d13c42c7d520c7b6112534b1c666653452f743b80bcc2d878455", "0x96dff377f354f792685a7e740e3024409c24a379425ff63e3ce320b1e9bc471" ], - "out": "0x16a858453a534784b95fb52f40e203db21c51a6bb9c13599c0d7aaa802e190aa" + "out": "0x18ea2fd58b4f3274193f3e73bded685426f114a8e7bc373c1aee3e6f0125787b" }, { "in": [ @@ -57,7 +57,7 @@ "0x12f4c27e5a2e80dd67fb33928c4e6219a8bdc89b498ed32acb02d725cec90076", "0x1d6b52c237f0f74f0c50755627eed2610608488b54b0a3941a4623b1d435232a" ], - "out": "0x2759e8bb5393ea5eb110eb977de64c01f6cbf9a434ed7d0679c8a937f39e19c" + "out": "0x2f237dea4570779296e2866383740b8e9ccf59577f8ff729880dadb58ae34d47" }, { "in": [ @@ -67,7 +67,7 @@ "0x2d4232cb721888f71997377de5ca195a5ae03d3eb9c87d2c04ef3664759036da", "0x2f623ee75518430e291d42e7aaa75f5291a1bbfed125426d39270046a26be35a" ], - "out": "0x2a1dd4b9af5fc4055302c33d7a24f0070c660d27984d3169757f83b2d8c5a52" + "out": "0x6246dee7e2d9560a074c50a06e6525e4a58395cea4a893c49d71e373f19b9d6" }, { "in": [ @@ -77,6 +77,6 @@ "0x8e0ddb80366c4c6c7dcb9090f4862d64ef40677d324a76a82e06ca33ad29a09", "0x170e8c954ca7e6526b743e92f796488afe5083a9c549358f730659c3e1cdbafa" ], - "out": "0x1d54497fe996571b75b403593832da3d668a825b0978dc57afd38b886ef41e4c" + "out": "0x1a2e7cffb5183898a8f4f6d4699bc272665ebffbb9d095576d2e21c45f012358" } ] \ No newline at end of file From e303fc238c9fe2328e288c8abd97f4a4abcf511c Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 22 Oct 2024 08:36:46 -0500 Subject: [PATCH 20/74] style: add trace in mimc generate --- ecc/bn254/fr/mimc/test_vectors/main.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ecc/bn254/fr/mimc/test_vectors/main.go b/ecc/bn254/fr/mimc/test_vectors/main.go index e5cdef6c1a..b2e306a590 100644 --- a/ecc/bn254/fr/mimc/test_vectors/main.go +++ b/ecc/bn254/fr/mimc/test_vectors/main.go @@ -3,9 +3,10 @@ package main import ( "encoding/json" "fmt" + "os" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" - "os" ) type numericalMiMCTestCase struct { @@ -22,6 +23,7 @@ func assertNoError(err error) { //go:generate go run main.go func main() { + fmt.Println("generating test vectors for MiMC...") var tests []numericalMiMCTestCase bytes, err := os.ReadFile("./vectors.json") From 548dfd14c4fa64eabea06ee1d3a27499ed976d99 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 22 Oct 2024 08:54:21 -0500 Subject: [PATCH 21/74] feat: update build tags for 32 bit target --- ecc/bls12-377/fp/element_purego.go | 2 +- ecc/bls12-377/fr/element_purego.go | 2 +- ecc/bls12-377/fr/vector_purego.go | 2 +- ecc/bls12-381/fp/element_purego.go | 2 +- ecc/bls12-381/fr/element_purego.go | 2 +- ecc/bls12-381/fr/vector_purego.go | 2 +- ecc/bls24-315/fp/element_purego.go | 2 +- ecc/bls24-315/fr/element_purego.go | 2 +- ecc/bls24-315/fr/vector_purego.go | 2 +- ecc/bls24-317/fp/element_purego.go | 2 +- ecc/bls24-317/fr/element_purego.go | 2 +- ecc/bls24-317/fr/vector_purego.go | 2 +- ecc/bn254/fp/element_purego.go | 2 +- ecc/bn254/fp/vector_purego.go | 2 +- 
ecc/bn254/fr/element_purego.go | 2 +- ecc/bn254/fr/vector_purego.go | 2 +- ecc/bw6-633/fp/element_purego.go | 2 +- ecc/bw6-633/fr/element_purego.go | 2 +- ecc/bw6-761/fp/element_purego.go | 2 +- ecc/bw6-761/fr/element_purego.go | 2 +- ecc/stark-curve/fp/element_purego.go | 2 +- ecc/stark-curve/fp/vector_purego.go | 2 +- ecc/stark-curve/fr/element_purego.go | 2 +- ecc/stark-curve/fr/vector_purego.go | 2 +- field/generator/generator.go | 8 ++++---- 25 files changed, 28 insertions(+), 28 deletions(-) diff --git a/ecc/bls12-377/fp/element_purego.go b/ecc/bls12-377/fp/element_purego.go index be2027de13..93afd90757 100644 --- a/ecc/bls12-377/fp/element_purego.go +++ b/ecc/bls12-377/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-377/fr/element_purego.go b/ecc/bls12-377/fr/element_purego.go index aa0c3e9a75..60de099977 100644 --- a/ecc/bls12-377/fr/element_purego.go +++ b/ecc/bls12-377/fr/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-377/fr/vector_purego.go b/ecc/bls12-377/fr/vector_purego.go index d09c259806..84f86a40b1 100644 --- a/ecc/bls12-377/fr/vector_purego.go +++ b/ecc/bls12-377/fr/vector_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fp/element_purego.go b/ecc/bls12-381/fp/element_purego.go index 511c52f1f1..ac96223a0a 100644 --- a/ecc/bls12-381/fp/element_purego.go +++ b/ecc/bls12-381/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fr/element_purego.go b/ecc/bls12-381/fr/element_purego.go index be6e50e1ff..193edfb5d2 100644 --- a/ecc/bls12-381/fr/element_purego.go +++ b/ecc/bls12-381/fr/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fr/vector_purego.go b/ecc/bls12-381/fr/vector_purego.go index d09c259806..84f86a40b1 100644 --- a/ecc/bls12-381/fr/vector_purego.go +++ b/ecc/bls12-381/fr/vector_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fp/element_purego.go b/ecc/bls24-315/fp/element_purego.go index 92806a73c7..f5c6690f7e 100644 --- a/ecc/bls24-315/fp/element_purego.go +++ b/ecc/bls24-315/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fr/element_purego.go b/ecc/bls24-315/fr/element_purego.go index 28d51a0862..a93f64f06e 100644 --- a/ecc/bls24-315/fr/element_purego.go +++ b/ecc/bls24-315/fr/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fr/vector_purego.go b/ecc/bls24-315/fr/vector_purego.go index d09c259806..84f86a40b1 100644 --- a/ecc/bls24-315/fr/vector_purego.go +++ b/ecc/bls24-315/fr/vector_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. 
// diff --git a/ecc/bls24-317/fp/element_purego.go b/ecc/bls24-317/fp/element_purego.go index 9c855b3475..0f21692f40 100644 --- a/ecc/bls24-317/fp/element_purego.go +++ b/ecc/bls24-317/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fr/element_purego.go b/ecc/bls24-317/fr/element_purego.go index af2c85d2cd..c9b8bd4616 100644 --- a/ecc/bls24-317/fr/element_purego.go +++ b/ecc/bls24-317/fr/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fr/vector_purego.go b/ecc/bls24-317/fr/vector_purego.go index d09c259806..84f86a40b1 100644 --- a/ecc/bls24-317/fr/vector_purego.go +++ b/ecc/bls24-317/fr/vector_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fp/element_purego.go b/ecc/bn254/fp/element_purego.go index 00300a2189..ded61b6088 100644 --- a/ecc/bn254/fp/element_purego.go +++ b/ecc/bn254/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fp/vector_purego.go b/ecc/bn254/fp/vector_purego.go index c6d37d76f4..fc0f66e2a3 100644 --- a/ecc/bn254/fp/vector_purego.go +++ b/ecc/bn254/fp/vector_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fr/element_purego.go b/ecc/bn254/fr/element_purego.go index 78b5d8a64c..034ee75ffe 100644 --- a/ecc/bn254/fr/element_purego.go +++ b/ecc/bn254/fr/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fr/vector_purego.go b/ecc/bn254/fr/vector_purego.go index d09c259806..84f86a40b1 100644 --- a/ecc/bn254/fr/vector_purego.go +++ b/ecc/bn254/fr/vector_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-633/fp/element_purego.go b/ecc/bw6-633/fp/element_purego.go index 6ba677acfe..811df961eb 100644 --- a/ecc/bw6-633/fp/element_purego.go +++ b/ecc/bw6-633/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-633/fr/element_purego.go b/ecc/bw6-633/fr/element_purego.go index 4d38ad7730..93ea6339b3 100644 --- a/ecc/bw6-633/fr/element_purego.go +++ b/ecc/bw6-633/fr/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fp/element_purego.go b/ecc/bw6-761/fp/element_purego.go index 21b07566fd..128e162748 100644 --- a/ecc/bw6-761/fp/element_purego.go +++ b/ecc/bw6-761/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fr/element_purego.go b/ecc/bw6-761/fr/element_purego.go index a67b43b678..5d56b28300 100644 --- a/ecc/bw6-761/fr/element_purego.go +++ b/ecc/bw6-761/fr/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. 
// diff --git a/ecc/stark-curve/fp/element_purego.go b/ecc/stark-curve/fp/element_purego.go index 380a94c00c..5308610f0b 100644 --- a/ecc/stark-curve/fp/element_purego.go +++ b/ecc/stark-curve/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/stark-curve/fp/vector_purego.go b/ecc/stark-curve/fp/vector_purego.go index c6d37d76f4..fc0f66e2a3 100644 --- a/ecc/stark-curve/fp/vector_purego.go +++ b/ecc/stark-curve/fp/vector_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/stark-curve/fr/element_purego.go b/ecc/stark-curve/fr/element_purego.go index 7aea910af6..ae05fe839b 100644 --- a/ecc/stark-curve/fr/element_purego.go +++ b/ecc/stark-curve/fr/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/stark-curve/fr/vector_purego.go b/ecc/stark-curve/fr/vector_purego.go index d09c259806..84f86a40b1 100644 --- a/ecc/stark-curve/fr/vector_purego.go +++ b/ecc/stark-curve/fr/vector_purego.go @@ -1,4 +1,4 @@ -//go:build purego || arm64 +//go:build purego || !amd64 // Copyright 2020 ConsenSys Software Inc. // diff --git a/field/generator/generator.go b/field/generator/generator.go index 2b6dd2d9f0..22e0292778 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -126,17 +126,17 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude } // purego files have no build tags if we don't generate asm - pureGoBuildTag := "purego" + pureGoBuildTag := "purego || (!amd64 && !arm64)" if !F.GenerateOpsAMD64 && !F.GenerateOpsARM64 { pureGoBuildTag = "" } else if !F.GenerateOpsARM64 { - pureGoBuildTag = "purego || arm64" + pureGoBuildTag = "purego || (!amd64)" } - pureGoVectorBuildTag := "purego" + pureGoVectorBuildTag := "purego || (!amd64 && !arm64)" if !F.GenerateVectorOpsAMD64 && !F.GenerateVectorOpsARM64 { pureGoVectorBuildTag = "" } else if !F.GenerateVectorOpsARM64 { - pureGoVectorBuildTag = "purego || arm64" + pureGoVectorBuildTag = "purego || (!amd64)" } var g errgroup.Group From 147a62e55adfb2e28b7a62abe86949e1a1f726ce Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 22 Oct 2024 19:42:10 +0000 Subject: [PATCH 22/74] feat: generalize arm64 mul for larger modulus --- ecc/bls12-377/fp/element_arm64.s | 2 +- ecc/bls12-377/fr/element_arm64.s | 2 +- ecc/bls12-381/fp/element_arm64.s | 2 +- ecc/bls12-381/fr/element_arm64.s | 2 +- ecc/bls24-315/fr/element_arm64.s | 2 +- ecc/bls24-317/fr/element_arm64.s | 2 +- ecc/bn254/fp/element_arm64.s | 2 +- ecc/bn254/fr/element_arm64.s | 2 +- ecc/bw6-633/fp/element_purego.go | 2 +- ecc/bw6-761/fp/element_purego.go | 2 +- ecc/bw6-761/fr/element_arm64.s | 2 +- ecc/stark-curve/fp/element_arm64.s | 2 +- ecc/stark-curve/fr/element_arm64.s | 2 +- field/asm/.gitignore | 7 +- field/asm/element_4w_arm64.s | 140 ++++----- field/asm/element_6w_arm64.s | 202 ++++++------- field/generator/asm/arm64/build.go | 4 +- field/generator/asm/arm64/element_ops.go | 267 ------------------ field/generator/config/field_config.go | 2 +- field/generator/generator_test.go | 3 + .../internal/templates/element/ops_asm.go | 6 + go.mod | 2 +- go.sum | 4 + internal/generator/main.go | 2 + 24 files changed, 210 insertions(+), 455 deletions(-) delete mode 100644 field/generator/asm/arm64/element_ops.go diff --git a/ecc/bls12-377/fp/element_arm64.s 
b/ecc/bls12-377/fp/element_arm64.s index 4c01eca83e..2a3f7d0b2c 100644 --- a/ecc/bls12-377/fp/element_arm64.s +++ b/ecc/bls12-377/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17561434332277668166 +// We include the hash to force the Go compiler to recompile: 15397482240260640864 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-377/fr/element_arm64.s b/ecc/bls12-377/fr/element_arm64.s index 75bf9d9d1f..5d328815aa 100644 --- a/ecc/bls12-377/fr/element_arm64.s +++ b/ecc/bls12-377/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fp/element_arm64.s b/ecc/bls12-381/fp/element_arm64.s index 4c01eca83e..2a3f7d0b2c 100644 --- a/ecc/bls12-381/fp/element_arm64.s +++ b/ecc/bls12-381/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17561434332277668166 +// We include the hash to force the Go compiler to recompile: 15397482240260640864 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-381/fr/element_arm64.s b/ecc/bls12-381/fr/element_arm64.s index 75bf9d9d1f..5d328815aa 100644 --- a/ecc/bls12-381/fr/element_arm64.s +++ b/ecc/bls12-381/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element_arm64.s b/ecc/bls24-315/fr/element_arm64.s index 75bf9d9d1f..5d328815aa 100644 --- a/ecc/bls24-315/fr/element_arm64.s +++ b/ecc/bls24-315/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element_arm64.s b/ecc/bls24-317/fr/element_arm64.s index 75bf9d9d1f..5d328815aa 100644 --- a/ecc/bls24-317/fr/element_arm64.s +++ b/ecc/bls24-317/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element_arm64.s b/ecc/bn254/fp/element_arm64.s index 75bf9d9d1f..5d328815aa 100644 --- a/ecc/bn254/fp/element_arm64.s +++ b/ecc/bn254/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element_arm64.s b/ecc/bn254/fr/element_arm64.s index 75bf9d9d1f..5d328815aa 100644 --- a/ecc/bn254/fr/element_arm64.s +++ b/ecc/bn254/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by 
consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bw6-633/fp/element_purego.go b/ecc/bw6-633/fp/element_purego.go index 811df961eb..637ecd9d67 100644 --- a/ecc/bw6-633/fp/element_purego.go +++ b/ecc/bw6-633/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego || !amd64 +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fp/element_purego.go b/ecc/bw6-761/fp/element_purego.go index 128e162748..4338d90c2b 100644 --- a/ecc/bw6-761/fp/element_purego.go +++ b/ecc/bw6-761/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego || !amd64 +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fr/element_arm64.s b/ecc/bw6-761/fr/element_arm64.s index 4c01eca83e..2a3f7d0b2c 100644 --- a/ecc/bw6-761/fr/element_arm64.s +++ b/ecc/bw6-761/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17561434332277668166 +// We include the hash to force the Go compiler to recompile: 15397482240260640864 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/stark-curve/fp/element_arm64.s b/ecc/stark-curve/fp/element_arm64.s index 75bf9d9d1f..5d328815aa 100644 --- a/ecc/stark-curve/fp/element_arm64.s +++ b/ecc/stark-curve/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_arm64.s b/ecc/stark-curve/fr/element_arm64.s index 75bf9d9d1f..5d328815aa 100644 --- a/ecc/stark-curve/fr/element_arm64.s +++ b/ecc/stark-curve/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/.gitignore b/field/asm/.gitignore index 7c22f7f933..d534769fcb 100644 --- a/field/asm/.gitignore +++ b/field/asm/.gitignore @@ -3,4 +3,9 @@ element_2w_amd64.s element_3w_amd64.s element_7w_amd64.s element_8w_amd64.s -*.h \ No newline at end of file +*.h + +element_2w_arm64.s +element_3w_arm64.s +element_7w_arm64.s +element_8w_arm64.s \ No newline at end of file diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index 58eccf4142..fce96e3000 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -54,89 +54,89 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 #define DIVSHIFT() \ - MUL R7, R17, R0 \ - ADDS R0, R11, R11 \ - MUL R8, R17, R0 \ - ADCS R0, R12, R12 \ - MUL R9, R17, R0 \ - ADCS R0, R13, R13 \ - MUL R10, R17, R0 \ - ADCS R0, R14, R14 \ - ADC R15, ZR, R15 \ - UMULH R7, R17, R0 \ - ADDS R0, R12, R11 \ - UMULH R8, R17, R0 \ - ADCS R0, R13, R12 \ - UMULH R9, R17, R0 \ - ADCS R0, R14, R13 \ - UMULH R10, R17, R0 \ - ADCS R0, R15, R14 \ + MUL R13, R12, R0 \ + ADDS R0, R6, R6 \ + MUL R14, R12, R0 \ + ADCS R0, R7, R7 \ + MUL R15, R12, R0 \ + ADCS R0, R8, R8 \ + MUL R16, R12, R0 \ + ADCS R0, R9, R9 \ + ADC R10, ZR, R10 \ + UMULH R13, R12, R0 \ + ADDS R0, R7, R6 \ + UMULH R14, R12, R0 \ + ADCS R0, R8, R7 \ + UMULH R15, R12, R0 \ + ADCS R0, R9, R8 \ + UMULH R16, R12, R0 \ + ADCS R0, R10, R9 \ #define MUL_WORD_N() \ - MUL R3, R2, R0 \ - ADDS R0, R11, R11 \ - MUL R11, R16, R17 \ - MUL R4, R2, R0 \ - ADCS R0, R12, R12 \ - MUL R5, R2, R0 \ - ADCS R0, R13, R13 \ - MUL R6, R2, R0 \ - ADCS R0, R14, R14 \ - ADC ZR, ZR, R15 \ - UMULH R3, R2, R0 \ - ADDS R0, R12, R12 \ - UMULH R4, R2, R0 \ - ADCS R0, R13, R13 \ - UMULH R5, R2, R0 \ - ADCS R0, R14, R14 \ - UMULH R6, R2, R0 \ - ADC R0, R15, R15 \ - DIVSHIFT() \ + MUL R2, R1, R0 \ + ADDS R0, R6, R6 \ + MUL R6, R11, R12 \ + MUL R3, R1, R0 \ + ADCS R0, R7, R7 \ + MUL R4, R1, R0 \ + ADCS R0, R8, R8 \ + MUL R5, R1, R0 \ + ADCS R0, R9, R9 \ + ADC ZR, ZR, R10 \ + UMULH R2, R1, R0 \ + ADDS R0, R7, R7 \ + UMULH R3, R1, R0 \ + ADCS R0, R8, R8 \ + UMULH R4, R1, R0 \ + ADCS R0, R9, R9 \ + UMULH R5, R1, R0 \ + ADC R0, R10, R10 \ + DIVSHIFT() \ #define MUL_WORD_0() \ - MUL R3, R2, R11 \ - MUL R4, R2, R12 \ - MUL R5, R2, R13 \ - MUL R6, R2, R14 \ - UMULH R3, R2, R0 \ - ADDS R0, R12, R12 \ - UMULH R4, R2, R0 \ - ADCS R0, R13, R13 \ - UMULH R5, R2, R0 \ - ADCS R0, R14, R14 \ - UMULH R6, R2, R0 \ - ADC R0, ZR, R15 \ - MUL R11, R16, R17 \ - DIVSHIFT() \ + MUL R2, R1, R6 \ + MUL R3, R1, R7 \ + MUL R4, R1, R8 \ + MUL R5, R1, R9 \ + UMULH R2, R1, R0 \ + ADDS R0, R7, R7 \ + UMULH R3, R1, R0 \ + ADCS R0, R8, R8 \ + UMULH R4, R1, R0 \ + ADCS R0, R9, R9 \ + UMULH R5, R1, R0 \ + ADC R0, ZR, R10 \ + MUL R6, R11, R12 \ + DIVSHIFT() \ - MOVD y+16(FP), R1 + MOVD y+16(FP), R17 MOVD x+8(FP), R0 - LDP 0(R0), (R3, R4) - LDP 16(R0), (R5, R6) - MOVD 0(R1), R2 - MOVD $const_qInvNeg, R16 - LDP ·qElement+0(SB), (R7, R8) - LDP ·qElement+16(SB), (R9, R10) + LDP 0(R0), (R2, R3) + LDP 16(R0), (R4, R5) + MOVD 0(R17), R1 + MOVD $const_qInvNeg, R11 + LDP ·qElement+0(SB), (R13, R14) + LDP ·qElement+16(SB), (R15, R16) MUL_WORD_0() - MOVD 8(R1), R2 + MOVD 8(R17), R1 MUL_WORD_N() - MOVD 16(R1), R2 + MOVD 16(R17), R1 MUL_WORD_N() - MOVD 24(R1), R2 + MOVD 24(R17), R1 MUL_WORD_N() // reduce if necessary - SUBS R7, R11, R7 - SBCS R8, R12, R8 - SBCS R9, R13, R9 - SBCS R10, R14, R10 + SUBS R13, R6, R13 + SBCS R14, R7, R14 + SBCS R15, R8, R15 + SBCS R16, R9, R16 MOVD res+0(FP), R0 - CSEL CS, R7, R11, R11 - CSEL CS, R8, R12, R12 - STP (R11, R12), 0(R0) - CSEL CS, R9, R13, R13 - CSEL CS, R10, R14, R14 - STP (R13, R14), 16(R0) + CSEL CS, R13, R6, R6 + CSEL CS, R14, R7, R7 + STP (R6, R7), 0(R0) + CSEL CS, R15, R8, R8 + CSEL CS, R16, R9, R9 + STP (R8, R9), 16(R0) RET // reduce(res *Element) diff --git a/field/asm/element_6w_arm64.s b/field/asm/element_6w_arm64.s index 7b4946b3f9..7dbd7ecaf3 100644 --- a/field/asm/element_6w_arm64.s +++ b/field/asm/element_6w_arm64.s @@ -71,122 +71,122 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 #define DIVSHIFT() \ - MUL R9, R24, R0 \ - ADDS R0, R15, R15 \ - MUL R10, R24, R0 \ - ADCS R0, R16, R16 \ - MUL R11, R24, R0 \ - ADCS R0, R17, R17 \ - MUL R12, R24, R0 \ - ADCS R0, R19, R19 \ - MUL R13, R24, R0 \ - ADCS R0, R20, R20 \ - MUL R14, R24, R0 \ - ADCS R0, R21, R21 \ - ADC R22, ZR, R22 \ - UMULH R9, R24, R0 \ - ADDS R0, R16, R15 \ - UMULH R10, R24, R0 \ - ADCS R0, R17, R16 \ - UMULH R11, R24, R0 \ - ADCS R0, R19, R17 \ - UMULH R12, R24, R0 \ - ADCS R0, R20, R19 \ - UMULH R13, R24, R0 \ - ADCS R0, R21, R20 \ - UMULH R14, R24, R0 \ - ADCS R0, R22, R21 \ + MUL R17, R16, R0 \ + ADDS R0, R8, R8 \ + MUL R19, R16, R0 \ + ADCS R0, R9, R9 \ + MUL R20, R16, R0 \ + ADCS R0, R10, R10 \ + MUL R21, R16, R0 \ + ADCS R0, R11, R11 \ + MUL R22, R16, R0 \ + ADCS R0, R12, R12 \ + MUL R23, R16, R0 \ + ADCS R0, R13, R13 \ + ADC R14, ZR, R14 \ + UMULH R17, R16, R0 \ + ADDS R0, R9, R8 \ + UMULH R19, R16, R0 \ + ADCS R0, R10, R9 \ + UMULH R20, R16, R0 \ + ADCS R0, R11, R10 \ + UMULH R21, R16, R0 \ + ADCS R0, R12, R11 \ + UMULH R22, R16, R0 \ + ADCS R0, R13, R12 \ + UMULH R23, R16, R0 \ + ADCS R0, R14, R13 \ #define MUL_WORD_N() \ - MUL R3, R2, R0 \ - ADDS R0, R15, R15 \ - MUL R15, R23, R24 \ - MUL R4, R2, R0 \ - ADCS R0, R16, R16 \ - MUL R5, R2, R0 \ - ADCS R0, R17, R17 \ - MUL R6, R2, R0 \ - ADCS R0, R19, R19 \ - MUL R7, R2, R0 \ - ADCS R0, R20, R20 \ - MUL R8, R2, R0 \ - ADCS R0, R21, R21 \ - ADC ZR, ZR, R22 \ - UMULH R3, R2, R0 \ - ADDS R0, R16, R16 \ - UMULH R4, R2, R0 \ - ADCS R0, R17, R17 \ - UMULH R5, R2, R0 \ - ADCS R0, R19, R19 \ - UMULH R6, R2, R0 \ - ADCS R0, R20, R20 \ - UMULH R7, R2, R0 \ - ADCS R0, R21, R21 \ - UMULH R8, R2, R0 \ - ADC R0, R22, R22 \ - DIVSHIFT() \ + MUL R2, R1, R0 \ + ADDS R0, R8, R8 \ + MUL R8, R15, R16 \ + MUL R3, R1, R0 \ + ADCS R0, R9, R9 \ + MUL R4, R1, R0 \ + ADCS R0, R10, R10 \ + MUL R5, R1, R0 \ + ADCS R0, R11, R11 \ + MUL R6, R1, R0 \ + ADCS R0, R12, R12 \ + MUL R7, R1, R0 \ + ADCS R0, R13, R13 \ + ADC ZR, ZR, R14 \ + UMULH R2, R1, R0 \ + ADDS R0, R9, R9 \ + UMULH R3, R1, R0 \ + ADCS R0, R10, R10 \ + UMULH R4, R1, R0 \ + ADCS R0, R11, R11 \ + UMULH R5, R1, R0 \ + ADCS R0, R12, R12 \ + UMULH R6, R1, R0 \ + ADCS R0, R13, R13 \ + UMULH R7, R1, R0 \ + ADC R0, R14, R14 \ + DIVSHIFT() \ #define MUL_WORD_0() \ - MUL R3, R2, R15 \ - MUL R4, R2, R16 \ - MUL R5, R2, R17 \ - MUL R6, R2, R19 \ - MUL R7, R2, R20 \ - MUL R8, R2, R21 \ - UMULH R3, R2, R0 \ - ADDS R0, R16, R16 \ - UMULH R4, R2, R0 \ - ADCS R0, R17, R17 \ - UMULH R5, R2, R0 \ - ADCS R0, R19, R19 \ - UMULH R6, R2, R0 \ - ADCS R0, R20, R20 \ - UMULH R7, R2, R0 \ - ADCS R0, R21, R21 \ - UMULH R8, R2, R0 \ - ADC R0, ZR, R22 \ - MUL R15, R23, R24 \ - DIVSHIFT() \ + MUL R2, R1, R8 \ + MUL R3, R1, R9 \ + MUL R4, R1, R10 \ + MUL R5, R1, R11 \ + MUL R6, R1, R12 \ + MUL R7, R1, R13 \ + UMULH R2, R1, R0 \ + ADDS R0, R9, R9 \ + UMULH R3, R1, R0 \ + ADCS R0, R10, R10 \ + UMULH R4, R1, R0 \ + ADCS R0, R11, R11 \ + UMULH R5, R1, R0 \ + ADCS R0, R12, R12 \ + UMULH R6, R1, R0 \ + ADCS R0, R13, R13 \ + UMULH R7, R1, R0 \ + ADC R0, ZR, R14 \ + MUL R8, R15, R16 \ + DIVSHIFT() \ - MOVD y+16(FP), R1 + MOVD y+16(FP), R24 MOVD x+8(FP), R0 - LDP 0(R0), (R3, R4) - LDP 16(R0), (R5, R6) - LDP 32(R0), (R7, R8) - MOVD 0(R1), R2 - MOVD $const_qInvNeg, R23 - LDP ·qElement+0(SB), (R9, R10) - LDP ·qElement+16(SB), (R11, R12) - LDP ·qElement+32(SB), (R13, R14) + LDP 0(R0), (R2, R3) + LDP 16(R0), (R4, R5) + LDP 32(R0), (R6, R7) + MOVD 0(R24), R1 + MOVD 
$const_qInvNeg, R15 + LDP ·qElement+0(SB), (R17, R19) + LDP ·qElement+16(SB), (R20, R21) + LDP ·qElement+32(SB), (R22, R23) MUL_WORD_0() - MOVD 8(R1), R2 + MOVD 8(R24), R1 MUL_WORD_N() - MOVD 16(R1), R2 + MOVD 16(R24), R1 MUL_WORD_N() - MOVD 24(R1), R2 + MOVD 24(R24), R1 MUL_WORD_N() - MOVD 32(R1), R2 + MOVD 32(R24), R1 MUL_WORD_N() - MOVD 40(R1), R2 + MOVD 40(R24), R1 MUL_WORD_N() // reduce if necessary - SUBS R9, R15, R9 - SBCS R10, R16, R10 - SBCS R11, R17, R11 - SBCS R12, R19, R12 - SBCS R13, R20, R13 - SBCS R14, R21, R14 + SUBS R17, R8, R17 + SBCS R19, R9, R19 + SBCS R20, R10, R20 + SBCS R21, R11, R21 + SBCS R22, R12, R22 + SBCS R23, R13, R23 MOVD res+0(FP), R0 - CSEL CS, R9, R15, R15 - CSEL CS, R10, R16, R16 - STP (R15, R16), 0(R0) - CSEL CS, R11, R17, R17 - CSEL CS, R12, R19, R19 - STP (R17, R19), 16(R0) - CSEL CS, R13, R20, R20 - CSEL CS, R14, R21, R21 - STP (R20, R21), 32(R0) + CSEL CS, R17, R8, R8 + CSEL CS, R19, R9, R9 + STP (R8, R9), 0(R0) + CSEL CS, R20, R10, R10 + CSEL CS, R21, R11, R11 + STP (R10, R11), 16(R0) + CSEL CS, R22, R12, R12 + CSEL CS, R23, R13, R13 + STP (R12, R13), 32(R0) RET // reduce(res *Element) diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index cba7bde8bc..08bc52fb1c 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -107,7 +107,9 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { panic("NbWords must be even") } - f.generateButterfly() + if f.NbWords <= 6 { + f.generateButterfly() + } f.generateMul() f.generateReduce() diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go deleted file mode 100644 index 354cd68d99..0000000000 --- a/field/generator/asm/arm64/element_ops.go +++ /dev/null @@ -1,267 +0,0 @@ -// Copyright 2022 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
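(For orientation while reading the generator source below, which is deleted here and re-added in split form in the next patch: the emitted MUL_WORD_0/MUL_WORD_N/DIVSHIFT macros implement word-by-word Montgomery multiplication, Algorithm 2 of the paper cited above. Below is a minimal pure-Go sketch of the same loop for a 4-word modulus — illustrative only, not the library's implementation — assuming the usual precomputations (q, qInvNeg = -q^(-1) mod 2^64) and a modulus with a spare top bit, the NoCarry condition checked in field_config.go:

    package sketch

    import "math/bits"

    // mulMont returns x*y*2^(-256) mod q: a plain-Go rendering of the
    // generated assembly's outer loop over the words of y.
    func mulMont(x, y, q *[4]uint64, qInvNeg uint64) [4]uint64 {
        var t [4]uint64 // accumulator
        var tN uint64   // extra high word (t[NbWords] in the generator)

        for i := 0; i < 4; i++ {
            // t += x * y[i]                        (MUL_WORD_0 / MUL_WORD_N)
            var c, carry uint64
            for j := 0; j < 4; j++ {
                hi, lo := bits.Mul64(x[j], y[i])
                lo, carry = bits.Add64(lo, c, 0)
                hi += carry
                t[j], carry = bits.Add64(t[j], lo, 0)
                c = hi + carry
            }
            tN += c // cannot overflow while q keeps a spare top bit

            // m is chosen so that t + m*q == 0 (mod 2^64) ...
            m := t[0] * qInvNeg

            // ... hence (t + m*q)/2^64 is exact     (DIVSHIFT)
            hi, lo := bits.Mul64(m, q[0])
            _, carry = bits.Add64(lo, t[0], 0) // low word vanishes; keep carry
            c = hi + carry
            for j := 1; j < 4; j++ {
                hi, lo = bits.Mul64(m, q[j])
                lo, carry = bits.Add64(lo, c, 0)
                hi += carry
                t[j-1], carry = bits.Add64(t[j], lo, 0)
                c = hi + carry
            }
            t[3], tN = bits.Add64(tN, c, 0) // shift down one word
        }

        // single conditional subtraction: t < 2q, so t or t-q is the result
        var z [4]uint64
        var b uint64
        for j := 0; j < 4; j++ {
            z[j], b = bits.Sub64(t[j], q[j], b)
        }
        if b != 0 { // borrow: t was already < q
            return t
        }
        return z
    }
)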
- -package arm64 - -import ( - "github.com/consensys/bavard/arm64" -) - -func (f *FFArm64) generateButterfly() { - f.Comment("butterfly(a, b *Element)") - f.Comment("a, b = a+b, a-b") - registers := f.FnHeader("Butterfly", 0, 16) - defer f.AssertCleanStack(0, 0) - - // registers - a := registers.PopN(f.NbWords) - b := registers.PopN(f.NbWords) - r := registers.PopN(f.NbWords) - t := registers.PopN(f.NbWords) - aPtr := registers.Pop() - bPtr := registers.Pop() - - f.LDP("x+0(FP)", aPtr, bPtr) - f.load(aPtr, a) - f.load(bPtr, b) - - for i := 0; i < f.NbWords; i++ { - f.add0n(i)(a[i], b[i], r[i]) - } - - f.SUBS(b[0], a[0], b[0]) - for i := 1; i < f.NbWords; i++ { - f.SBCS(b[i], a[i], b[i]) - } - - for i := 0; i < f.NbWords; i++ { - if i%2 == 0 { - f.LDP(f.qAt(i), a[i], a[i+1]) - } - f.CSEL("CS", "ZR", a[i], t[i]) - } - f.Comment("add q if underflow, 0 if not") - for i := 0; i < f.NbWords; i++ { - f.add0n(i)(b[i], t[i], b[i]) - if i%2 == 1 { - f.STP(b[i-1], b[i], bPtr.At(i-1)) - } - } - - f.reduceAndStore(r, a, aPtr) - - f.RET() -} - -func (f *FFArm64) generateReduce() { - f.Comment("reduce(res *Element)") - registers := f.FnHeader("reduce", 0, 8) - defer f.AssertCleanStack(0, 0) - - // registers - t := registers.PopN(f.NbWords) - q := registers.PopN(f.NbWords) - rPtr := registers.Pop() - - for i := 0; i < f.NbWords; i += 2 { - f.LDP(f.qAt(i), q[i], q[i+1]) - } - - f.MOVD("res+0(FP)", rPtr) - f.load(rPtr, t) - f.reduceAndStore(t, q, rPtr) - - f.RET() -} - -func (f *FFArm64) generateMul() { - f.Comment("mul(res, x, y *Element)") - f.Comment("Algorithm 2 of Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS") - f.Comment("by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521") - registers := f.FnHeader("mul", 0, 24) - defer f.AssertCleanStack(0, 0) - - xPtr := registers.Pop() - yPtr := registers.Pop() - bi := registers.Pop() - a := registers.PopN(f.NbWords) - q := registers.PopN(f.NbWords) - t := registers.PopN(f.NbWords + 1) - - ax := xPtr - qInv0 := registers.Pop() - m := registers.Pop() - - divShift := f.Define("divShift", 0, func(args ...arm64.Register) { - // for j=0 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - - for j := 0; j < f.NbWords; j++ { - f.MUL(q[j], m, ax) - f.add0m(j)(ax, t[j], t[j]) - } - f.add0m(f.NbWords)(t[f.NbWords], "ZR", t[f.NbWords]) - - // propagate high bits - f.UMULH(q[0], m, ax) - for j := 1; j <= f.NbWords; j++ { - f.add1m(j, true)(ax, t[j], t[j-1]) - if j != f.NbWords { - f.UMULH(q[j], m, ax) - } - } - }) - - mulWordN := f.Define("MUL_WORD_N", 0, func(args ...arm64.Register) { - // for j=0 to N-1 - // (C,t[j]) := t[j] + a[j]*b[i] + C - - // lo bits - for j := 0; j < f.NbWords; j++ { - f.MUL(a[j], bi, ax) - f.add0m(j)(ax, t[j], t[j]) - - if j == 0 { - f.MUL(t[0], qInv0, m) - } - } - f.add0m(f.NbWords)("ZR", "ZR", t[f.NbWords]) - - // propagate high bits - f.UMULH(a[0], bi, ax) - for j := 1; j <= f.NbWords; j++ { - f.add1m(j)(ax, t[j], t[j]) - if j != f.NbWords { - f.UMULH(a[j], bi, ax) - } - } - divShift() - }) - - mulWord0 := f.Define("MUL_WORD_0", 0, func(args ...arm64.Register) { - // for j=0 to N-1 - // (C,t[j]) := t[j] + a[j]*b[i] + C - // lo bits - for j := 0; j < f.NbWords; j++ { - f.MUL(a[j], bi, t[j]) - } - - // propagate high bits - f.UMULH(a[0], bi, ax) - for j := 1; j < f.NbWords; j++ { - f.add1m(j)(ax, t[j], t[j]) - f.UMULH(a[j], bi, ax) - } - f.add1m(f.NbWords)(ax, "ZR", t[f.NbWords]) - f.MUL(t[0], qInv0, m) - divShift() - }) - - f.MOVD("y+16(FP)", yPtr) - f.MOVD("x+8(FP)", xPtr) - f.load(xPtr, a) - 
for i := 0; i < f.NbWords; i++ { - f.MOVD(yPtr.At(i), bi) - - if i == 0 { - // load qInv0 and q at first iteration. - f.MOVD(f.qInv0(), qInv0) - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.qAt(i), q[i], q[i+1]) - } - mulWord0() - } else { - mulWordN() - } - } - - f.Comment("reduce if necessary") - f.SUBS(q[0], t[0], q[0]) - for i := 1; i < f.NbWords; i++ { - f.SBCS(q[i], t[i], q[i]) - } - - f.MOVD("res+0(FP)", ax) - for i := 0; i < f.NbWords; i++ { - f.CSEL("CS", q[i], t[i], t[i]) - if i%2 == 1 { - f.STP(t[i-1], t[i], ax.At(i-1)) - } - } - - f.RET() -} - -func (f *FFArm64) load(zPtr arm64.Register, z []arm64.Register) { - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(zPtr.At(i), z[i], z[i+1]) - } -} - -// q must contain the modulus -// q is modified -// t = t mod q (t must be less than 2q) -// t is stored in zPtr -func (f *FFArm64) reduceAndStore(t, q []arm64.Register, zPtr arm64.Register) { - f.Comment("q = t - q") - f.SUBS(q[0], t[0], q[0]) - for i := 1; i < f.NbWords; i++ { - f.SBCS(q[i], t[i], q[i]) - } - - f.Comment("if no borrow, return q, else return t") - for i := 0; i < f.NbWords; i++ { - f.CSEL("CS", q[i], t[i], t[i]) - if i%2 == 1 { - f.STP(t[i-1], t[i], zPtr.At(i-1)) - } - } -} - -func (f *FFArm64) add0n(i int) func(op1, op2, dst interface{}, comment ...string) { - switch { - case i == 0: - return f.ADDS - case i == f.NbWordsLastIndex: - return f.ADC - default: - return f.ADCS - } -} - -func (f *FFArm64) add0m(i int) func(op1, op2, dst interface{}, comment ...string) { - switch { - case i == 0: - return f.ADDS - case i == f.NbWordsLastIndex+1: - return f.ADC - default: - return f.ADCS - } -} - -func (f *FFArm64) add1m(i int, dumb ...bool) func(op1, op2, dst interface{}, comment ...string) { - switch { - case i == 1: - return f.ADDS - case i == f.NbWordsLastIndex+1: - if len(dumb) == 1 && dumb[0] { - // odd, but it performs better on c8g instances. 
- return f.ADCS - } - return f.ADC - default: - return f.ADCS - } -} diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 1efe7d8b32..f6e5513b4c 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -267,7 +267,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // asm code generation for moduli with more than 6 words can be optimized further F.GenerateOpsAMD64 = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 F.GenerateVectorOpsAMD64 = F.GenerateOpsAMD64 && F.NbWords == 4 && F.NbBits > 225 - F.GenerateOpsARM64 = F.GenerateOpsAMD64 && (F.NbWords == 6 || F.NbWords == 4) + F.GenerateOpsARM64 = F.GenerateOpsAMD64 && (F.NbWords%2 == 0) F.GenerateVectorOpsARM64 = false // setting Mu 2^288 / q diff --git a/field/generator/generator_test.go b/field/generator/generator_test.go index e490baa03d..ee5b5fafc0 100644 --- a/field/generator/generator_test.go +++ b/field/generator/generator_test.go @@ -84,6 +84,9 @@ func TestIntegration(t *testing.T) { assert.NoError(GenerateAMD64(7, asmDir, false)) assert.NoError(GenerateAMD64(8, asmDir, false)) + assert.NoError(GenerateARM64(2, asmDir, false)) + assert.NoError(GenerateARM64(8, asmDir, false)) + for elementName, modulus := range moduli { var fIntegration *field.FieldConfig // generate field diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 271d7d1165..52b2c0b2b7 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -50,8 +50,14 @@ const OpsARM64 = ` // Butterfly sets // a = a + b (mod q) // b = a - b (mod q) +{{- if le .NbWords 6}} //go:noescape func Butterfly(a, b *{{.ElementName}}) +{{- else}} +func Butterfly(a, b *{{.ElementName}}) { + _butterflyGeneric(a, b) +} +{{- end}} //go:noescape func mul(res,x,y *{{.ElementName}}) diff --git a/go.mod b/go.mod index 4297d50546..c4dbc8bc84 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.23-0.20241021201139-ab3fee069cde + github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index 40ca76c748..af73f869f3 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,10 @@ github.com/consensys/bavard v0.1.23-0.20241019150039-28659c2eb91c h1:sK5i7h6ZVAj github.com/consensys/bavard v0.1.23-0.20241019150039-28659c2eb91c/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/bavard v0.1.23-0.20241021201139-ab3fee069cde h1:KXywceL5kuPe9PAQHHBvt4Kki7/XqsW7ABJI9dn4zik= github.com/consensys/bavard v0.1.23-0.20241021201139-ab3fee069cde/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241022191117-d73e50a886cc h1:NwWCvGXSPH8BYATHBdy7qTJ3NMoT1kWVAvuEPtvasqg= +github.com/consensys/bavard v0.1.23-0.20241022191117-d73e50a886cc/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3 h1:8gPxbjhwhxXTakOXII32eLlAFLlYImoENa3uQ6iP+go= +github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= 
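(The ops_asm.go template change above, {{- if le .NbWords 6}}, pairs with the build.go guard earlier in this patch: arm64 assembly for Butterfly is only emitted up to 6 words. The limit is register pressure — generateButterfly pops four full-width register sets plus two pointers, i.e. 4*NbWords + 2 general-purpose registers, which no longer fits on arm64 at 10 or 12 words; the mul path instead reloads q from memory, the fatModulus branch added in the following patch. For wider fields the template emits a call to the generic version, which — as a sketch under that assumption, not necessarily the generated source — looks like:

    // _butterflyGeneric sets a, b = a+b, a-b (mod q).
    func _butterflyGeneric(a, b *Element) {
        t := *a     // save the original a
        a.Add(a, b) // a = a + b
        b.Sub(&t, b) // b = (old a) - b
    }
)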
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= diff --git a/internal/generator/main.go b/internal/generator/main.go index 45fbb4fe20..c17605a414 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -60,6 +60,8 @@ func main() { assertNoError(generator.GenerateARM64(4, asmDirBuildPath, false)) assertNoError(generator.GenerateARM64(6, asmDirBuildPath, false)) + assertNoError(generator.GenerateARM64(10, asmDirBuildPath, false)) + assertNoError(generator.GenerateARM64(12, asmDirBuildPath, false)) var wg sync.WaitGroup for _, conf := range config.Curves { From 3638f44eac6b33ee1fba4565b9c4e52f1b99a46e Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 22 Oct 2024 19:45:10 +0000 Subject: [PATCH 23/74] feat: add missing files --- ecc/bw6-633/fp/element_arm64.go | 83 +++++ ecc/bw6-633/fp/element_arm64.s | 21 ++ ecc/bw6-761/fp/element_arm64.go | 85 +++++ ecc/bw6-761/fp/element_arm64.s | 21 ++ field/asm/element_10w_arm64.s | 266 +++++++++++++++ field/asm/element_12w_arm64.s | 312 ++++++++++++++++++ .../generator/asm/arm64/element_butterfly.go | 47 +++ field/generator/asm/arm64/element_mul.go | 259 +++++++++++++++ 8 files changed, 1094 insertions(+) create mode 100644 ecc/bw6-633/fp/element_arm64.go create mode 100644 ecc/bw6-633/fp/element_arm64.s create mode 100644 ecc/bw6-761/fp/element_arm64.go create mode 100644 ecc/bw6-761/fp/element_arm64.s create mode 100644 field/asm/element_10w_arm64.s create mode 100644 field/asm/element_12w_arm64.s create mode 100644 field/generator/asm/arm64/element_butterfly.go create mode 100644 field/generator/asm/arm64/element_mul.go diff --git a/ecc/bw6-633/fp/element_arm64.go b/ecc/bw6-633/fp/element_arm64.go new file mode 100644 index 0000000000..aeabec816c --- /dev/null +++ b/ecc/bw6-633/fp/element_arm64.go @@ -0,0 +1,83 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. 
+ mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 4881606927653498122, + 47978232019095094, + 8555661377410121478, + 17849732488791568215, + 5227097555314997552, + 839611732066804726, + 5234648925333584632, + 11936054402769696488, + 1228498468693814883, + 2857848702739380, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +//go:noescape +func reduce(res *Element) diff --git a/ecc/bw6-633/fp/element_arm64.s b/ecc/bw6-633/fp/element_arm64.s new file mode 100644 index 0000000000..19ca3bc382 --- /dev/null +++ b/ecc/bw6-633/fp/element_arm64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 4283725514119985738 +#include "../../../field/asm/element_10w_arm64.s" + diff --git a/ecc/bw6-761/fp/element_arm64.go b/ecc/bw6-761/fp/element_arm64.go new file mode 100644 index 0000000000..df5acdc99a --- /dev/null +++ b/ecc/bw6-761/fp/element_arm64.go @@ -0,0 +1,85 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +//go:noescape +func mul(res, x, y *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. 
+ mul(z, x, x) + return z +} + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 4345973640412121648, + 16340807117537158706, + 14673764841507373218, + 5587754667198343811, + 12846753860245084942, + 4041391838244625385, + 8324122986343791677, + 8773809490091176420, + 5465994123296109449, + 6649773564661156048, + 9147430723089113754, + 54281803719730243, + } + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +//go:noescape +func reduce(res *Element) diff --git a/ecc/bw6-761/fp/element_arm64.s b/ecc/bw6-761/fp/element_arm64.s new file mode 100644 index 0000000000..9ed6049e69 --- /dev/null +++ b/ecc/bw6-761/fp/element_arm64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 17465962485072383759 +#include "../../../field/asm/element_12w_arm64.s" + diff --git a/field/asm/element_10w_arm64.s b/field/asm/element_10w_arm64.s new file mode 100644 index 0000000000..529fcae61d --- /dev/null +++ b/field/asm/element_10w_arm64.s @@ -0,0 +1,266 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. +#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +// mul(res, x, y *Element) +// Algorithm 2 of Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS +// by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 +TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 +#define DIVSHIFT() \ + MOVD $const_qInvNeg, R0 \ + MUL R12, R0, R1 \ + MOVD ·qElement+0(SB), R0 \ + MUL R0, R1, R0 \ + ADDS R0, R12, R12 \ + MOVD ·qElement+8(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R13, R13 \ + MOVD ·qElement+16(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R14, R14 \ + MOVD ·qElement+24(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R15, R15 \ + MOVD ·qElement+32(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R16, R16 \ + MOVD ·qElement+40(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R17, R17 \ + MOVD ·qElement+48(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R19, R19 \ + MOVD ·qElement+56(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R20, R20 \ + MOVD ·qElement+64(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R21, R21 \ + MOVD ·qElement+72(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R22, R22 \ + ADC R23, ZR, R23 \ + MOVD ·qElement+0(SB), R0 \ + UMULH R0, R1, R0 \ + ADDS R0, R13, R12 \ + MOVD ·qElement+8(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R14, R13 \ + MOVD ·qElement+16(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R15, R14 \ + MOVD ·qElement+24(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R16, R15 \ + MOVD ·qElement+32(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R17, R16 \ + MOVD ·qElement+40(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R19, R17 \ + MOVD ·qElement+48(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R20, R19 \ + MOVD ·qElement+56(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R21, R20 \ + MOVD ·qElement+64(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R22, R21 \ + MOVD ·qElement+72(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R23, R22 \ + +#define MUL_WORD_N() \ + MUL R2, R1, R0 \ + ADDS R0, R12, R12 \ + MUL R3, R1, R0 \ + ADCS R0, R13, R13 \ + MUL R4, R1, R0 \ + ADCS R0, R14, R14 \ + MUL R5, R1, R0 \ + ADCS R0, R15, R15 \ + MUL R6, R1, R0 \ + ADCS R0, R16, R16 \ + MUL R7, R1, R0 \ + ADCS R0, R17, R17 \ + MUL R8, R1, R0 \ + ADCS R0, R19, R19 \ + MUL R9, R1, R0 \ + ADCS R0, R20, R20 \ + MUL R10, R1, R0 \ + ADCS R0, R21, R21 \ + MUL R11, R1, R0 \ + ADCS R0, R22, R22 \ + ADC ZR, ZR, R23 \ + UMULH R2, R1, R0 \ + ADDS R0, R13, R13 \ + UMULH R3, R1, R0 \ + ADCS R0, R14, R14 \ + UMULH R4, R1, R0 \ + ADCS R0, R15, R15 \ + UMULH R5, R1, R0 \ + ADCS R0, R16, R16 \ + UMULH R6, R1, R0 \ + ADCS R0, R17, R17 \ + UMULH R7, R1, R0 \ + ADCS R0, R19, R19 \ + UMULH R8, R1, R0 \ + ADCS R0, R20, R20 \ + UMULH R9, R1, R0 \ + ADCS R0, R21, R21 \ + UMULH R10, R1, R0 \ + ADCS R0, R22, R22 \ + UMULH R11, R1, R0 \ + ADC R0, R23, R23 \ + DIVSHIFT() \ + +#define MUL_WORD_0() \ + MUL R2, R1, R12 \ + MUL R3, R1, R13 \ + MUL R4, R1, R14 \ + MUL R5, R1, R15 \ + MUL R6, R1, R16 \ + MUL R7, R1, R17 \ + MUL R8, R1, R19 \ + MUL R9, R1, R20 \ + MUL R10, R1, R21 \ + MUL R11, R1, R22 \ + UMULH R2, R1, R0 \ + ADDS R0, R13, R13 \ + UMULH R3, R1, R0 \ + ADCS R0, R14, R14 \ + UMULH R4, R1, R0 \ + ADCS R0, R15, R15 \ + UMULH R5, R1, R0 \ + ADCS R0, R16, R16 \ + UMULH R6, R1, R0 \ + ADCS R0, R17, R17 \ + UMULH R7, R1, R0 \ + ADCS R0, R19, R19 \ + UMULH R8, R1, R0 \ + ADCS R0, R20, R20 \ + UMULH R9, R1, R0 \ + ADCS R0, R21, R21 \ + UMULH R10, R1, R0 \ + ADCS R0, R22, R22 \ + UMULH R11, R1, R0 \ + ADC R0, ZR, R23 \ + DIVSHIFT() \ + + MOVD y+16(FP), R1 + MOVD x+8(FP), R0 + LDP 0(R0), (R2, R3) + LDP 16(R0), (R4, R5) + LDP 32(R0), (R6, R7) + LDP 48(R0), (R8, R9) + LDP 64(R0), (R10, R11) + MOVD y+16(FP), R1 + MOVD 0(R1), R1 + MUL_WORD_0() + MOVD y+16(FP), R1 + MOVD 8(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 16(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 24(R1), R1 + 
MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 32(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 40(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 48(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 56(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 64(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 72(R1), R1 + MUL_WORD_N() + LDP ·qElement+0(SB), (R2, R3) + LDP ·qElement+16(SB), (R4, R5) + LDP ·qElement+32(SB), (R6, R7) + LDP ·qElement+48(SB), (R8, R9) + LDP ·qElement+64(SB), (R10, R11) + + // reduce if necessary + SUBS R2, R12, R2 + SBCS R3, R13, R3 + SBCS R4, R14, R4 + SBCS R5, R15, R5 + SBCS R6, R16, R6 + SBCS R7, R17, R7 + SBCS R8, R19, R8 + SBCS R9, R20, R9 + SBCS R10, R21, R10 + SBCS R11, R22, R11 + MOVD res+0(FP), R0 + CSEL CS, R2, R12, R12 + CSEL CS, R3, R13, R13 + STP (R12, R13), 0(R0) + CSEL CS, R4, R14, R14 + CSEL CS, R5, R15, R15 + STP (R14, R15), 16(R0) + CSEL CS, R6, R16, R16 + CSEL CS, R7, R17, R17 + STP (R16, R17), 32(R0) + CSEL CS, R8, R19, R19 + CSEL CS, R9, R20, R20 + STP (R19, R20), 48(R0) + CSEL CS, R10, R21, R21 + CSEL CS, R11, R22, R22 + STP (R21, R22), 64(R0) + RET + +// reduce(res *Element) +TEXT ·reduce(SB), NOFRAME|NOSPLIT, $0-8 + LDP ·qElement+0(SB), (R10, R11) + LDP ·qElement+16(SB), (R12, R13) + LDP ·qElement+32(SB), (R14, R15) + LDP ·qElement+48(SB), (R16, R17) + LDP ·qElement+64(SB), (R19, R20) + MOVD res+0(FP), R21 + LDP 0(R21), (R0, R1) + LDP 16(R21), (R2, R3) + LDP 32(R21), (R4, R5) + LDP 48(R21), (R6, R7) + LDP 64(R21), (R8, R9) + + // q = t - q + SUBS R10, R0, R10 + SBCS R11, R1, R11 + SBCS R12, R2, R12 + SBCS R13, R3, R13 + SBCS R14, R4, R14 + SBCS R15, R5, R15 + SBCS R16, R6, R16 + SBCS R17, R7, R17 + SBCS R19, R8, R19 + SBCS R20, R9, R20 + + // if no borrow, return q, else return t + CSEL CS, R10, R0, R0 + CSEL CS, R11, R1, R1 + STP (R0, R1), 0(R21) + CSEL CS, R12, R2, R2 + CSEL CS, R13, R3, R3 + STP (R2, R3), 16(R21) + CSEL CS, R14, R4, R4 + CSEL CS, R15, R5, R5 + STP (R4, R5), 32(R21) + CSEL CS, R16, R6, R6 + CSEL CS, R17, R7, R7 + STP (R6, R7), 48(R21) + CSEL CS, R19, R8, R8 + CSEL CS, R20, R9, R9 + STP (R8, R9), 64(R21) + RET diff --git a/field/asm/element_12w_arm64.s b/field/asm/element_12w_arm64.s new file mode 100644 index 0000000000..a03e790ae6 --- /dev/null +++ b/field/asm/element_12w_arm64.s @@ -0,0 +1,312 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. +#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +// mul(res, x, y *Element) +// Algorithm 2 of Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS +// by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 +TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 +#define DIVSHIFT() \ + MOVD $const_qInvNeg, R0 \ + MUL R14, R0, R1 \ + MOVD ·qElement+0(SB), R0 \ + MUL R0, R1, R0 \ + ADDS R0, R14, R14 \ + MOVD ·qElement+8(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R15, R15 \ + MOVD ·qElement+16(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R16, R16 \ + MOVD ·qElement+24(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R17, R17 \ + MOVD ·qElement+32(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R19, R19 \ + MOVD ·qElement+40(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R20, R20 \ + MOVD ·qElement+48(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R21, R21 \ + MOVD ·qElement+56(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R22, R22 \ + MOVD ·qElement+64(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R23, R23 \ + MOVD ·qElement+72(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R24, R24 \ + MOVD ·qElement+80(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R25, R25 \ + MOVD ·qElement+88(SB), R0 \ + MUL R0, R1, R0 \ + ADCS R0, R26, R26 \ + ADC R29, ZR, R29 \ + MOVD ·qElement+0(SB), R0 \ + UMULH R0, R1, R0 \ + ADDS R0, R15, R14 \ + MOVD ·qElement+8(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R16, R15 \ + MOVD ·qElement+16(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R17, R16 \ + MOVD ·qElement+24(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R19, R17 \ + MOVD ·qElement+32(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R20, R19 \ + MOVD ·qElement+40(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R21, R20 \ + MOVD ·qElement+48(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R22, R21 \ + MOVD ·qElement+56(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R23, R22 \ + MOVD ·qElement+64(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R24, R23 \ + MOVD ·qElement+72(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R25, R24 \ + MOVD ·qElement+80(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R26, R25 \ + MOVD ·qElement+88(SB), R0 \ + UMULH R0, R1, R0 \ + ADCS R0, R29, R26 \ + +#define MUL_WORD_N() \ + MUL R2, R1, R0 \ + ADDS R0, R14, R14 \ + MUL R3, R1, R0 \ + ADCS R0, R15, R15 \ + MUL R4, R1, R0 \ + ADCS R0, R16, R16 \ + MUL R5, R1, R0 \ + ADCS R0, R17, R17 \ + MUL R6, R1, R0 \ + ADCS R0, R19, R19 \ + MUL R7, R1, R0 \ + ADCS R0, R20, R20 \ + MUL R8, R1, R0 \ + ADCS R0, R21, R21 \ + MUL R9, R1, R0 \ + ADCS R0, R22, R22 \ + MUL R10, R1, R0 \ + ADCS R0, R23, R23 \ + MUL R11, R1, R0 \ + ADCS R0, R24, R24 \ + MUL R12, R1, R0 \ + ADCS R0, R25, R25 \ + MUL R13, R1, R0 \ + ADCS R0, R26, R26 \ + ADC ZR, ZR, R29 \ + UMULH R2, R1, R0 \ + ADDS R0, R15, R15 \ + UMULH R3, R1, R0 \ + ADCS R0, R16, R16 \ + UMULH R4, R1, R0 \ + ADCS R0, R17, R17 \ + UMULH R5, R1, R0 \ + ADCS R0, R19, R19 \ + UMULH R6, R1, R0 \ + ADCS R0, R20, R20 \ + UMULH R7, R1, R0 \ + ADCS R0, R21, R21 \ + UMULH R8, R1, R0 \ + ADCS R0, R22, R22 \ + UMULH R9, R1, R0 \ + ADCS R0, R23, R23 \ + UMULH R10, R1, R0 \ + ADCS R0, R24, R24 \ + UMULH R11, R1, R0 \ + ADCS R0, R25, R25 \ + UMULH R12, R1, R0 \ + ADCS R0, R26, R26 \ + UMULH R13, R1, R0 \ + ADC R0, R29, R29 \ + DIVSHIFT() \ + +#define MUL_WORD_0() \ + MUL R2, R1, R14 \ + MUL R3, R1, R15 \ + MUL R4, R1, R16 \ + MUL R5, R1, R17 \ + MUL R6, R1, R19 \ + MUL R7, R1, R20 \ + MUL R8, R1, R21 \ + MUL R9, R1, R22 \ + MUL R10, R1, R23 \ + MUL R11, R1, R24 \ + MUL R12, R1, R25 \ + MUL R13, R1, R26 \ + UMULH R2, R1, R0 \ + ADDS R0, R15, R15 \ + UMULH R3, R1, R0 \ + ADCS R0, R16, R16 \ + UMULH R4, R1, R0 \ + ADCS R0, R17, R17 \ + UMULH R5, R1, R0 \ + ADCS R0, R19, R19 \ + UMULH R6, R1, R0 \ + ADCS R0, R20, R20 \ + UMULH R7, R1, R0 \ + ADCS R0, R21, R21 \ + UMULH R8, R1, R0 \ + ADCS R0, R22, R22 
\ + UMULH R9, R1, R0 \ + ADCS R0, R23, R23 \ + UMULH R10, R1, R0 \ + ADCS R0, R24, R24 \ + UMULH R11, R1, R0 \ + ADCS R0, R25, R25 \ + UMULH R12, R1, R0 \ + ADCS R0, R26, R26 \ + UMULH R13, R1, R0 \ + ADC R0, ZR, R29 \ + DIVSHIFT() \ + + MOVD y+16(FP), R1 + MOVD x+8(FP), R0 + LDP 0(R0), (R2, R3) + LDP 16(R0), (R4, R5) + LDP 32(R0), (R6, R7) + LDP 48(R0), (R8, R9) + LDP 64(R0), (R10, R11) + LDP 80(R0), (R12, R13) + MOVD y+16(FP), R1 + MOVD 0(R1), R1 + MUL_WORD_0() + MOVD y+16(FP), R1 + MOVD 8(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 16(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 24(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 32(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 40(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 48(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 56(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 64(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 72(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 80(R1), R1 + MUL_WORD_N() + MOVD y+16(FP), R1 + MOVD 88(R1), R1 + MUL_WORD_N() + LDP ·qElement+0(SB), (R2, R3) + LDP ·qElement+16(SB), (R4, R5) + LDP ·qElement+32(SB), (R6, R7) + LDP ·qElement+48(SB), (R8, R9) + LDP ·qElement+64(SB), (R10, R11) + LDP ·qElement+80(SB), (R12, R13) + + // reduce if necessary + SUBS R2, R14, R2 + SBCS R3, R15, R3 + SBCS R4, R16, R4 + SBCS R5, R17, R5 + SBCS R6, R19, R6 + SBCS R7, R20, R7 + SBCS R8, R21, R8 + SBCS R9, R22, R9 + SBCS R10, R23, R10 + SBCS R11, R24, R11 + SBCS R12, R25, R12 + SBCS R13, R26, R13 + MOVD res+0(FP), R0 + CSEL CS, R2, R14, R14 + CSEL CS, R3, R15, R15 + STP (R14, R15), 0(R0) + CSEL CS, R4, R16, R16 + CSEL CS, R5, R17, R17 + STP (R16, R17), 16(R0) + CSEL CS, R6, R19, R19 + CSEL CS, R7, R20, R20 + STP (R19, R20), 32(R0) + CSEL CS, R8, R21, R21 + CSEL CS, R9, R22, R22 + STP (R21, R22), 48(R0) + CSEL CS, R10, R23, R23 + CSEL CS, R11, R24, R24 + STP (R23, R24), 64(R0) + CSEL CS, R12, R25, R25 + CSEL CS, R13, R26, R26 + STP (R25, R26), 80(R0) + RET + +// reduce(res *Element) +TEXT ·reduce(SB), NOFRAME|NOSPLIT, $0-8 + LDP ·qElement+0(SB), (R12, R13) + LDP ·qElement+16(SB), (R14, R15) + LDP ·qElement+32(SB), (R16, R17) + LDP ·qElement+48(SB), (R19, R20) + LDP ·qElement+64(SB), (R21, R22) + LDP ·qElement+80(SB), (R23, R24) + MOVD res+0(FP), R25 + LDP 0(R25), (R0, R1) + LDP 16(R25), (R2, R3) + LDP 32(R25), (R4, R5) + LDP 48(R25), (R6, R7) + LDP 64(R25), (R8, R9) + LDP 80(R25), (R10, R11) + + // q = t - q + SUBS R12, R0, R12 + SBCS R13, R1, R13 + SBCS R14, R2, R14 + SBCS R15, R3, R15 + SBCS R16, R4, R16 + SBCS R17, R5, R17 + SBCS R19, R6, R19 + SBCS R20, R7, R20 + SBCS R21, R8, R21 + SBCS R22, R9, R22 + SBCS R23, R10, R23 + SBCS R24, R11, R24 + + // if no borrow, return q, else return t + CSEL CS, R12, R0, R0 + CSEL CS, R13, R1, R1 + STP (R0, R1), 0(R25) + CSEL CS, R14, R2, R2 + CSEL CS, R15, R3, R3 + STP (R2, R3), 16(R25) + CSEL CS, R16, R4, R4 + CSEL CS, R17, R5, R5 + STP (R4, R5), 32(R25) + CSEL CS, R19, R6, R6 + CSEL CS, R20, R7, R7 + STP (R6, R7), 48(R25) + CSEL CS, R21, R8, R8 + CSEL CS, R22, R9, R9 + STP (R8, R9), 64(R25) + CSEL CS, R23, R10, R10 + CSEL CS, R24, R11, R11 + STP (R10, R11), 80(R25) + RET diff --git a/field/generator/asm/arm64/element_butterfly.go b/field/generator/asm/arm64/element_butterfly.go new file mode 100644 index 0000000000..064e3a2843 --- /dev/null +++ b/field/generator/asm/arm64/element_butterfly.go @@ -0,0 +1,47 @@ +package arm64 + +func (f *FFArm64) generateButterfly() { + f.Comment("butterfly(a, b *Element)") + f.Comment("a, b = a+b, a-b") + registers := 
f.FnHeader("Butterfly", 0, 16) + defer f.AssertCleanStack(0, 0) + + // registers + a := registers.PopN(f.NbWords) + b := registers.PopN(f.NbWords) + r := registers.PopN(f.NbWords) + t := registers.PopN(f.NbWords) + aPtr := registers.Pop() + bPtr := registers.Pop() + + f.LDP("x+0(FP)", aPtr, bPtr) + f.load(aPtr, a) + f.load(bPtr, b) + + for i := 0; i < f.NbWords; i++ { + f.add0n(i)(a[i], b[i], r[i]) + } + + f.SUBS(b[0], a[0], b[0]) + for i := 1; i < f.NbWords; i++ { + f.SBCS(b[i], a[i], b[i]) + } + + for i := 0; i < f.NbWords; i++ { + if i%2 == 0 { + f.LDP(f.qAt(i), a[i], a[i+1]) + } + f.CSEL("CS", "ZR", a[i], t[i]) + } + f.Comment("add q if underflow, 0 if not") + for i := 0; i < f.NbWords; i++ { + f.add0n(i)(b[i], t[i], b[i]) + if i%2 == 1 { + f.STP(b[i-1], b[i], bPtr.At(i-1)) + } + } + + f.reduceAndStore(r, a, aPtr) + + f.RET() +} diff --git a/field/generator/asm/arm64/element_mul.go b/field/generator/asm/arm64/element_mul.go new file mode 100644 index 0000000000..2a2de0ad0c --- /dev/null +++ b/field/generator/asm/arm64/element_mul.go @@ -0,0 +1,259 @@ +// Copyright 2022 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arm64 + +import ( + "github.com/consensys/bavard/arm64" +) + +func (f *FFArm64) generateMul() { + f.Comment("mul(res, x, y *Element)") + f.Comment("Algorithm 2 of Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS") + f.Comment("by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521") + registers := f.FnHeader("mul", 0, 24) + defer f.AssertCleanStack(0, 0) + + fatModulus := f.NbWords > 6 + + xPtr := registers.Pop() + // yPtr := registers.Pop() + bi := registers.Pop() + a := registers.PopN(f.NbWords) + t := registers.PopN(f.NbWords + 1) + + var qInv0, m, yPtr arm64.Register + var q []arm64.Register + ax := xPtr + if fatModulus { + qInv0 = ax + m = bi + yPtr = bi + } else { + qInv0 = registers.Pop() + m = registers.Pop() + q = registers.PopN(f.NbWords) + yPtr = registers.Pop() + } + + qAt := func(i int) arm64.Register { + if !fatModulus { + return q[i] + } + f.MOVD(f.qAt(i), ax) + return ax + } + + divShift := f.Define("divShift", 0, func(args ...arm64.Register) { + if fatModulus { + f.MOVD(f.qInv0(), qInv0) + f.MUL(t[0], qInv0, m) + } + // for j=0 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + + for j := 0; j < f.NbWords; j++ { + f.MUL(qAt(j), m, ax) + f.add0m(j)(ax, t[j], t[j]) + } + f.add0m(f.NbWords)(t[f.NbWords], "ZR", t[f.NbWords]) + + // propagate high bits + f.UMULH(qAt(0), m, ax) + for j := 1; j <= f.NbWords; j++ { + f.add1m(j, true)(ax, t[j], t[j-1]) + if j != f.NbWords { + f.UMULH(qAt(j), m, ax) + } + } + }) + + mulWordN := f.Define("MUL_WORD_N", 0, func(args ...arm64.Register) { + // for j=0 to N-1 + // (C,t[j]) := t[j] + a[j]*b[i] + C + + // lo bits + for j := 0; j < f.NbWords; j++ { + f.MUL(a[j], bi, ax) + f.add0m(j)(ax, t[j], t[j]) + + if j == 0 && !fatModulus { + f.MUL(t[0], qInv0, m) + } + } + f.add0m(f.NbWords)("ZR", "ZR", t[f.NbWords]) + + // propagate high bits + f.UMULH(a[0], bi, ax) + for j := 1; j <= f.NbWords; j++ { + f.add1m(j)(ax, t[j], t[j]) + if j != f.NbWords { + f.UMULH(a[j], bi, ax) + } + } + divShift() + }) + + mulWord0 := f.Define("MUL_WORD_0", 0, func(args ...arm64.Register) { + // for j=0 to N-1 + // (C,t[j]) := t[j] + a[j]*b[i] + C + // lo bits + for j := 0; j < f.NbWords; j++ { + f.MUL(a[j], bi, t[j]) + } + + // propagate high bits + f.UMULH(a[0], bi, ax) + for j := 1; j < f.NbWords; j++ { + f.add1m(j)(ax, t[j], t[j]) + f.UMULH(a[j], bi, ax) + } + f.add1m(f.NbWords)(ax, "ZR", t[f.NbWords]) + if !fatModulus { + f.MUL(t[0], qInv0, m) + } + divShift() + }) + + f.MOVD("y+16(FP)", yPtr) + f.MOVD("x+8(FP)", xPtr) + f.load(xPtr, a) + + for i := 0; i < f.NbWords; i++ { + if fatModulus { + f.MOVD("y+16(FP)", yPtr) + } + f.MOVD(yPtr.At(i), bi) + + if i == 0 { + // load qInv0 and q at first iteration. 
+ if !fatModulus { + f.MOVD(f.qInv0(), qInv0) + for i := 0; i < f.NbWords; i += 2 { + f.LDP(f.qAt(i), q[i], q[i+1]) + } + } + mulWord0() + } else { + mulWordN() + } + } + + if fatModulus { + q = a + for i := 0; i < f.NbWords; i += 2 { + f.LDP(f.qAt(i), q[i], q[i+1]) + } + } + + f.Comment("reduce if necessary") + f.SUBS(q[0], t[0], q[0]) + for i := 1; i < f.NbWords; i++ { + f.SBCS(q[i], t[i], q[i]) + } + + f.MOVD("res+0(FP)", ax) + for i := 0; i < f.NbWords; i++ { + f.CSEL("CS", q[i], t[i], t[i]) + if i%2 == 1 { + f.STP(t[i-1], t[i], ax.At(i-1)) + } + } + + f.RET() +} + +func (f *FFArm64) generateReduce() { + f.Comment("reduce(res *Element)") + registers := f.FnHeader("reduce", 0, 8) + defer f.AssertCleanStack(0, 0) + + // registers + t := registers.PopN(f.NbWords) + q := registers.PopN(f.NbWords) + rPtr := registers.Pop() + + for i := 0; i < f.NbWords; i += 2 { + f.LDP(f.qAt(i), q[i], q[i+1]) + } + + f.MOVD("res+0(FP)", rPtr) + f.load(rPtr, t) + f.reduceAndStore(t, q, rPtr) + + f.RET() +} + +func (f *FFArm64) load(zPtr arm64.Register, z []arm64.Register) { + for i := 0; i < f.NbWords; i += 2 { + f.LDP(zPtr.At(i), z[i], z[i+1]) + } +} + +// q must contain the modulus +// q is modified +// t = t mod q (t must be less than 2q) +// t is stored in zPtr +func (f *FFArm64) reduceAndStore(t, q []arm64.Register, zPtr arm64.Register) { + f.Comment("q = t - q") + f.SUBS(q[0], t[0], q[0]) + for i := 1; i < f.NbWords; i++ { + f.SBCS(q[i], t[i], q[i]) + } + + f.Comment("if no borrow, return q, else return t") + for i := 0; i < f.NbWords; i++ { + f.CSEL("CS", q[i], t[i], t[i]) + if i%2 == 1 { + f.STP(t[i-1], t[i], zPtr.At(i-1)) + } + } +} + +func (f *FFArm64) add0n(i int) func(op1, op2, dst interface{}, comment ...string) { + switch { + case i == 0: + return f.ADDS + case i == f.NbWordsLastIndex: + return f.ADC + default: + return f.ADCS + } +} + +func (f *FFArm64) add0m(i int) func(op1, op2, dst interface{}, comment ...string) { + switch { + case i == 0: + return f.ADDS + case i == f.NbWordsLastIndex+1: + return f.ADC + default: + return f.ADCS + } +} + +func (f *FFArm64) add1m(i int, dumb ...bool) func(op1, op2, dst interface{}, comment ...string) { + switch { + case i == 1: + return f.ADDS + case i == f.NbWordsLastIndex+1: + if len(dumb) == 1 && dumb[0] { + // odd, but it performs better on c8g instances. 
+ return f.ADCS + } + return f.ADC + default: + return f.ADCS + } +} From ed06e35a324a18f6d95d978e9ff5b1c902fa8e4e Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 22 Nov 2024 13:16:11 -0600 Subject: [PATCH 24/74] checkpoint --- field/koalabear/arith.go | 60 ++ field/koalabear/doc.go | 53 ++ field/koalabear/element.go | 389 ++++++++++ field/koalabear/element_purego.go | 64 ++ field/koalabear/element_test.go | 836 +++++++++++++++++++++ field/koalabear/internal/addchain/3f | Bin 0 -> 90 bytes field/koalabear/internal/addchain/3f800000 | Bin 0 -> 212 bytes field/koalabear/internal/main.go | 21 + 8 files changed, 1423 insertions(+) create mode 100644 field/koalabear/arith.go create mode 100644 field/koalabear/doc.go create mode 100644 field/koalabear/element.go create mode 100644 field/koalabear/element_purego.go create mode 100644 field/koalabear/element_test.go create mode 100644 field/koalabear/internal/addchain/3f create mode 100644 field/koalabear/internal/addchain/3f800000 create mode 100644 field/koalabear/internal/main.go diff --git a/field/koalabear/arith.go b/field/koalabear/arith.go new file mode 100644 index 0000000000..0252a5bd1e --- /dev/null +++ b/field/koalabear/arith.go @@ -0,0 +1,60 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package koalabear + +import ( + "math/bits" +) + +// madd0 hi = a*b + c (discards lo bits) +func madd0(a, b, c uint64) (hi uint64) { + var carry, lo uint64 + hi, lo = bits.Mul64(a, b) + _, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd1 hi, lo = a*b + c +func madd1(a, b, c uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd2 hi, lo = a*b + c + d +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, e, carry) + return +} diff --git a/field/koalabear/doc.go b/field/koalabear/doc.go new file mode 100644 index 0000000000..6c49ce3be1 --- /dev/null +++ b/field/koalabear/doc.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+// Package koalabear contains field arithmetic operations for modulus = 0x7f000001.
+//
+// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@gnark/modular_multiplication)
+//
+// The modulus is hardcoded in all the operations.
+//
+// Field elements are represented as an array, and assumed to be in Montgomery form in all methods:
+//
+// type Element [1]uint64
+//
+// # Usage
+//
+// Example API signature:
+//
+// // Mul z = x * y (mod q)
+// func (z *Element) Mul(x, y *Element) *Element
+//
+// and can be used like so:
+//
+// var a, b Element
+// a.SetUint64(2)
+// b.SetString("984896738")
+// a.Mul(a, b)
+// a.Sub(a, a)
+// a.Add(a, b)
+// a.Inv(a)
+// b.Exp(b, new(big.Int).SetUint64(42))
+//
+// Modulus q =
+//
+// q[base10] = 2130706433
+// q[base16] = 0x7f000001
+//
+// # Warning
+//
+// This code has not been audited and is provided as-is. In particular, there are no security guarantees such as constant time implementation or side-channel attack resistance.
+package koalabear
diff --git a/field/koalabear/element.go b/field/koalabear/element.go
new file mode 100644
index 0000000000..fca7f64f55
--- /dev/null
+++ b/field/koalabear/element.go
@@ -0,0 +1,389 @@
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package koalabear
+
+import (
+	"crypto/rand"
+	"encoding/binary"
+	"io"
+	"math/big"
+	"math/bits"
+
+	"github.com/consensys/gnark-crypto/field/pool"
+)
+
+// rSquare where r is the Montgomery constant
+// see section 2.3.2 of Tolga Acar's thesis
+// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf
+var rSquare = Element{
+	1111325836,
+}
+
+
+// Element represents a field element stored on 1 word (uint64)
+//
+// Element values are assumed to be in Montgomery form in all methods.
+//
+// Modulus q =
+//
+// q[base10] = 2130706433
+// q[base16] = 0x7f000001
+//
+// # Warning
+//
+// This code has not been audited and is provided as-is. In particular, there are no security guarantees such as constant time implementation or side-channel attack resistance.
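+//
+// Montgomery form stores a value a as a·r mod q, with r = 2⁶⁴ for these
+// uint64 limbs; for instance 1 is represented by r mod q = 402124772, the
+// constant that SetOne loads below.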
+type Element [1]uint64
+
+const (
+	Limbs = 1  // number of 64-bit words needed to represent an Element
+	Bits  = 31 // number of bits needed to represent an Element
+	Bytes = 8  // number of bytes needed to represent an Element
+)
+
+// Field modulus q
+const (
+	q0 uint64 = 2130706433
+	q  uint64 = q0
+)
+
+var qElement = Element{
+	q0,
+}
+
+var _modulus big.Int // q stored as big.Int
+
+// Modulus returns q as a big.Int
+//
+// q[base10] = 2130706433
+// q[base16] = 0x7f000001
+func Modulus() *big.Int {
+	return new(big.Int).Set(&_modulus)
+}
+
+// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r
+// used for Montgomery reduction
+const qInvNeg uint64 = 13906834176474087423
+
+func init() {
+	_modulus.SetString("7f000001", 16)
+}
+
+// NewElement returns a new Element from a uint64 value
+//
+// it is equivalent to
+//
+// var v Element
+// v.SetUint64(...)
+func NewElement(v uint64) Element {
+	z := Element{v}
+	z.Mul(&z, &rSquare)
+	return z
+}
+
+// SetUint64 sets z to v and returns z
+func (z *Element) SetUint64(v uint64) *Element {
+	// sets z LSB to v (non-Montgomery form) and convert z to Montgomery form
+	*z = Element{v}
+	return z.Mul(z, &rSquare) // z.toMont()
+}
+
+// SetInt64 sets z to v and returns z
+func (z *Element) SetInt64(v int64) *Element {
+
+	// absolute value of v
+	m := v >> 63
+	z.SetUint64(uint64((v ^ m) - m))
+
+	if m != 0 {
+		// v is negative
+		z.Neg(z)
+	}
+
+	return z
+}
+
+// Set z = x and returns z
+func (z *Element) Set(x *Element) *Element {
+	z[0] = x[0]
+	return z
+}
+
+
+// SetZero z = 0
+func (z *Element) SetZero() *Element {
+	z[0] = 0
+	return z
+}
+
+// SetOne z = 1 (in Montgomery form)
+func (z *Element) SetOne() *Element {
+	z[0] = 402124772
+	return z
+}
+
+
+// Equal returns z == x; constant-time
+func (z *Element) Equal(x *Element) bool {
+	return z.NotEqual(x) == 0
+}
+
+// NotEqual returns 0 if and only if z == x; constant-time
+func (z *Element) NotEqual(x *Element) uint64 {
+	return (z[0] ^ x[0])
+}
+
+// IsZero returns z == 0
+func (z *Element) IsZero() bool {
+	return (z[0]) == 0
+}
+
+// IsOne returns z == 1
+func (z *Element) IsOne() bool {
+	return z[0] == 402124772
+}
+
+
+// SetRandom sets z to a uniform random value in [0, q).
+//
+// This might error only if reading from crypto/rand.Reader errors,
+// in which case the value of z is undefined.
+func (z *Element) SetRandom() (*Element, error) {
+	// this code is generated for all moduli
+	// and derived from go/src/crypto/rand/util.go
+
+	// l is the number of limbs * 8; the number of bytes needed to reconstruct 1 uint64
+	const l = 8
+
+	// bitLen is the maximum bit length needed to encode a value < q.
+	const bitLen = 31
+
+	// k is the maximum byte length needed to encode a value < q.
+	const k = (bitLen + 7) / 8
+
+	// b is the number of bits in the most significant byte of q-1.
+	b := uint(bitLen % 8)
+	if b == 0 {
+		b = 8
+	}
+
+	var bytes [l]byte
+
+	for {
+		// note that bytes[k:l] is always 0
+		if _, err := io.ReadFull(rand.Reader, bytes[:k]); err != nil {
+			return nil, err
+		}
+
+		// Clear unused bits in the most significant byte to increase probability
+		// that the candidate is < q.
+		bytes[k-1] &= uint8(int(1<<b) - 1)
+		z[0] = binary.LittleEndian.Uint64(bytes[0:8])
+
+		if z.smallerThanModulus() {
+			break
+		}
+	}
+
+	return z, nil
+}
+
+// smallerThanModulus returns true if z < q
+// This is not constant time
+func (z *Element) smallerThanModulus() bool {
+	return z[0] < q
+}
+
+// One returns 1
+func One() Element {
+	var one Element
+	one.SetOne()
+	return one
+}
+
+// Halve sets z to z / 2 (mod q)
+func (z *Element) Halve() {
+	if z[0]&1 == 1 {
+		// z = z + q
+		z[0] += q
+	}
+	// z = z >> 1
+	z[0] >>= 1
+
+}
+
+// fromMont converts z in place (i.e.
mutates) from Montgomery to regular representation +// sets and returns z = z * 1 +func (z *Element) fromMont() *Element { + fromMont(z) + return z +} + +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + z[0], _ = bits.Add64(x[0], y[0], 0) + if z[0] >= q { + z[0] -= q + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + if x[0]&(1<<63) == (1 << 63) { + // if highest bit is set, then we have a carry to x + x, we shift and subtract q + z[0] = (x[0] << 1) - q + } else { + // highest bit is not set, but x + x can still be >= q + z[0] = (x[0] << 1) + if z[0] >= q { + z[0] -= q + } + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + if b != 0 { + z[0] += q + } + return z +} + +// Neg z = q - x +func (z *Element) Neg(x *Element) *Element { + if x.IsZero() { + z.SetZero() + return z + } + z[0] = q - x[0] + return z +} + + +func _fromMontGeneric(z *Element) { + // the following lines implement z = z * 1 + // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + z[0] = C + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } +} + +func _reduceGeneric(z *Element) { + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } +} + + +// toMont converts z to Montgomery form +// sets and returns z = z * r² +func (z *Element) toMont() *Element { + return z.Mul(z, &rSquare) +} + + +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[0:8], z[0]) + + return res.SetBytes(b[:]) +} + +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) +} + +// ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead +func (z Element) ToBigIntRegular(res *big.Int) *big.Int { + z.fromMont() + return z.toBigInt(res) +} + + +// SetBigInt sets z to v and returns z +func (z *Element) SetBigInt(v *big.Int) *Element { + z.SetZero() + + var zero big.Int + + // fast path + c := v.Cmp(&_modulus) + if c == 0 { + // v == 0 + return z + } else if c != 1 && v.Cmp(&zero) != -1 { + // 0 < v < q + return z.setBigInt(v) + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + // copy input + modular reduction + vv.Mod(v, &_modulus) + + // set big int byte value + z.setBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + return z +} + +// setBigInt assumes 0 ⩽ v < q +func (z *Element) setBigInt(v *big.Int) *Element { + vBits := v.Bits() + + if bits.UintSize == 64 { + for i := 0; i < len(vBits); i++ { + z[i] = uint64(vBits[i]) + } + } else { + for i := 0; i < len(vBits); i++ { + if i%2 == 0 { + z[i/2] = uint64(vBits[i]) + } else { + z[i/2] |= uint64(vBits[i]) << 32 + } + } + } + + return z.toMont() +} diff --git a/field/koalabear/element_purego.go b/field/koalabear/element_purego.go new file mode 100644 index 0000000000..21b1eb5493 --- /dev/null +++ b/field/koalabear/element_purego.go @@ -0,0 +1,64 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package koalabear
+
+import "math/bits"
+
+
+
+func fromMont(z *Element) {
+	_fromMontGeneric(z)
+}
+
+func reduce(z *Element) {
+	_reduceGeneric(z)
+}
+
+// Mul z = x * y (mod q)
+//
+// x and y must be less than q
+func (z *Element) Mul(x, y *Element) *Element {
+
+	// In fact, since the modulus fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction):
+	// hi, lo := x * y
+	// m := (lo * qInvNeg) mod R
+	// (*) r := (hi * R + lo + m * q) / R
+	// reduce r if necessary
+
+	// On the starred line (*), we get r = hi + (lo + m * q) / R
+	// If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2, and (lo * qInvNeg) q ≡ -lo (mod R), so R | lo + lo2
+	// This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise.
+	// Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0)
+	// This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R.
+
+	var r uint64
+	hi, lo := bits.Mul64(x[0], y[0])
+	if lo != 0 {
+		hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow
+	}
+	m := lo * qInvNeg
+	hi2, _ := bits.Mul64(m, q)
+	r, carry := bits.Add64(hi2, hi, 0)
+
+	if carry != 0 || r >= q {
+		// we need to reduce
+		r -= q
+	}
+	z[0] = r
+
+	return z
+}
diff --git a/field/koalabear/element_test.go b/field/koalabear/element_test.go
new file mode 100644
index 0000000000..bacc70706c
--- /dev/null
+++ b/field/koalabear/element_test.go
@@ -0,0 +1,836 @@
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package koalabear + +import ( + "math/big" + "math/bits" + + "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" + + "github.com/stretchr/testify/require" +) + +// ------------------------------------------------------------------------------------------------- +// benchmarks +// most benchmarks are rudimentary and should sample a large number of random inputs +// or be run multiple times to ensure it didn't measure the fastest path of the function + +var benchResElement Element + + +func BenchmarkElementDouble(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Double(&benchResElement) + } +} + +func BenchmarkElementAdd(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Add(&x, &benchResElement) + } +} + +func BenchmarkElementSub(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sub(&x, &benchResElement) + } +} + +func BenchmarkElementNeg(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Neg(&benchResElement) + } +} + +func BenchmarkElementFromMont(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.fromMont() + } +} + + + + +func BenchmarkElementMul(b *testing.B) { + x := Element{ + 1111325836, + } + benchResElement.SetOne() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Mul(&benchResElement, &x) + } +} + + +func TestElementNegZero(t *testing.T) { + var a, b Element + b.SetZero() + for a.IsZero() { + a.SetRandom() + } + a.Neg(&b) + if !a.IsZero() { + t.Fatal("neg(0) != 0") + } +} + +// ------------------------------------------------------------------------------------------------- +// Gopter tests +// most of them are generated with a template + +const ( + nbFuzzShort = 200 + nbFuzz = 1000 +) + +// special values to be used in tests +var staticTestValues []Element + +func init() { + staticTestValues = append(staticTestValues, Element{}) // zero + staticTestValues = append(staticTestValues, One()) // one + staticTestValues = append(staticTestValues, rSquare) // r² + var e, one Element + one.SetOne() + e.Sub(&qElement, &one) + staticTestValues = append(staticTestValues, e) // q - 1 + e.Double(&one) + staticTestValues = append(staticTestValues, e) // 2 + + { + a := qElement + a[0]-- + staticTestValues = append(staticTestValues, a) + } + staticTestValues = append(staticTestValues, Element{0}) + staticTestValues = append(staticTestValues, Element{1}) + staticTestValues = append(staticTestValues, Element{2}) + + { + a := qElement + a[0]-- + staticTestValues = append(staticTestValues, a) + } + + { + a := qElement + a[0] = 0 + staticTestValues = append(staticTestValues, a) + } + +} + +func TestElementReduce(t *testing.T) { + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + s := testValues[i] + expected := s + reduce(&s) + _reduceGeneric(&expected) + if !s.Equal(&expected) { + t.Fatal("reduce failed: asm and generic impl don't match") + } + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := 
gopter.NewProperties(parameters) + + genA := genFull() + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(a Element) bool { + b := a + reduce(&a) + _reduceGeneric(&b) + return a.smallerThanModulus() && a.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementEqual(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("x.Equal(&y) iff x == y; likely false for random pairs", prop.ForAll( + func(a testPairElement, b testPairElement) bool { + return a.element.Equal(&b.element) == (a.element == b.element) + }, + genA, + genB, + )) + + properties.Property("x.Equal(&y) if x == y", prop.ForAll( + func(a testPairElement) bool { + b := a.element + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + + +func TestElementAdd(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Add(&a.element, &b.element) + a.element.Add(&a.element, &b.element) + b.element.Add(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Add: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Add(&a.element, &b.element) + + var d, e big.Int + d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Add(&a.element, &r) + d.Add(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Add(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Add(&a, &b) + d.Add(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Add failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSub(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + 
parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Sub(&a.element, &b.element) + a.element.Sub(&a.element, &b.element) + b.element.Sub(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Sub(&a.element, &b.element) + + var d, e big.Int + d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Sub(&a.element, &r) + d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Sub(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Sub(&a, &b) + d.Sub(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sub failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementMul(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Mul: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Mul(&a.element, &b.element) + a.element.Mul(&a.element, &b.element) + b.element.Mul(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Mul(&a.element, &b.element) + + var d, e big.Int + d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Mul(&a.element, &r) + d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) + + // checking generic impl 
against asm path + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Mul(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Mul(&a, &b) + d.Mul(&aBig, &bBig).Mod(&d, Modulus()) + + // checking asm against generic impl + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Mul failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + + +func TestElementDouble(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Double(&a.element) + a.element.Double(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Double: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + + var d, e big.Int + d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Double(&a) + + var d, e big.Int + d.Lsh(&aBig, 1).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Double failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementNeg(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Neg(&a.element) + a.element.Neg(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Neg: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + + var d, e big.Int + d.Neg(&a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) 
+ return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Neg(&a) + + var d, e big.Int + d.Neg(&aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Neg failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementNewElement(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + e := NewElement(1) + assert.True(e.IsOne()) + + e = NewElement(0) + assert.True(e.IsZero()) +} + + +func TestElementFromMont(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.fromMont() + _fromMontGeneric(&d) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("x.fromMont().toMont() == x", prop.ForAll( + func(a testPairElement) bool { + c := a.element + c.fromMont().toMont() + return c.Equal(&a.element) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +type testPairElement struct { + element Element + bigint big.Int +} + +func gen() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g testPairElement + + g.element = Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + g.element[0] %= (qElement[0] + 1) + } + + for !g.element.smallerThanModulus() { + g.element = Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + g.element[0] %= (qElement[0] + 1) + } + } + + g.element.BigInt(&g.bigint) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element + + g = Element{ + genParams.NextUint64(), + } + + if qElement[0] != ^uint64(0) { + g[0] %= (qElement[0] + 1) + } + + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + } + if qElement[0] != ^uint64(0) { + g[0] %= (qElement[0] + 1) + } + } + + return g +} + +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + + var carry uint64 + a[0], _ = bits.Add64(a[0], qElement[0], carry) + + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} diff --git a/field/koalabear/internal/addchain/3f b/field/koalabear/internal/addchain/3f new file mode 100644 index 0000000000000000000000000000000000000000..db4d062eb3428f2f32be62bc818f17072bee7103 GIT binary patch literal 90 zcmWlPISzmz00eiCAR)xk#@@eZqhGMHG8#VM@o|J^voo1KzXsYhoqO62@+a~m4R-@( exhhbC2`iY(N@Iz?Fvxar+|J=Z)O^l7PWeeN3 zv1=E5_AoKQfdd>m#E~N$JI09 Date: Sat, 23 Nov 2024 11:13:54 -0600 Subject: [PATCH 25/74] test passing --- field/generator/config/field_config.go | 57 +- .../internal/templates/element/arith.go | 48 +- 
 .../internal/templates/element/base.go        |   74 +-
 .../internal/templates/element/conv.go        |   44 +-
 .../internal/templates/element/inverse.go     |    2 +-
 .../internal/templates/element/mul_cios.go    |   30 +-
 .../internal/templates/element/tests.go       |   22 +-
 .../templates/element/tests_vector.go         |    8 +-
 .../internal/templates/element/vector.go      |    6 +-
 field/koalabear/arith.go                      |   48 +-
 field/koalabear/element.go                    |  675 ++++++-
 field/koalabear/element_purego.go             |   71 +-
 field/koalabear/element_test.go               | 1748 +++++++++++++++++-
 13 files changed, 2481 insertions(+), 352 deletions(-)

diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go
index f6e5513b4c..4e6e9e862c 100644
--- a/field/generator/config/field_config.go
+++ b/field/generator/config/field_config.go
@@ -73,6 +73,8 @@ type FieldConfig struct {
 	SqrtQ3Mod4ExponentData *addchain.AddChainData
 	UseAddChain            bool
 
+	Word Word // 32-bit limbs when the modulus fits in 31 bits (NbBits < 32), else 64-bit
+
 	// asm code generation
 	GenerateOpsAMD64 bool
 	GenerateOpsARM64 bool
@@ -80,6 +82,17 @@ type FieldConfig struct {
 	GenerateVectorOpsARM64 bool
 }
 
+type Word struct {
+	BitSize   int    // 32 or 64
+	ByteSize  int    // 4 or 8
+	TypeLower string // uint32 or uint64
+	TypeUpper string // Uint32 or Uint64
+	Add       string // Add64 or Add32
+	Sub       string // Sub64 or Sub32
+	Mul       string // Mul64 or Mul32
+	Len       string // Len64 or Len32
+}
+
 // NewFieldConfig returns a data structure with needed information to generate apis for field element
 //
 // See field/generator package
@@ -102,8 +115,6 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool)
 	// pre compute field constants
 	F.NbBits = bModulus.BitLen()
 	F.NbWords = len(bModulus.Bits())
-	F.NbBytes = F.NbWords * 8 // (F.NbBits + 7) / 8
-
 	F.NbWordsLastIndex = F.NbWords - 1
 
 	// set q from big int repr
@@ -114,9 +125,36 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool)
 	_qHalved.Sub(&bModulus, bOne).Rsh(_qHalved, 1).Add(_qHalved, bOne)
 	F.QMinusOneHalvedP = toUint64Slice(_qHalved, F.NbWords)
 
+	// Word size; we pick uint32 only if the modulus fits in 31 bits (NbBits < 32)
+	F.Word.BitSize = 64
+	F.Word.ByteSize = 8
+	F.Word.TypeLower = "uint64"
+	F.Word.TypeUpper = "Uint64"
+	F.Word.Add = "Add64"
+	F.Word.Sub = "Sub64"
+	F.Word.Mul = "Mul64"
+	F.Word.Len = "Len64"
+	if F.NbBits < 32 {
+		F.Word.BitSize = 32
+		F.Word.ByteSize = 4
+		F.Word.TypeLower = "uint32"
+		F.Word.TypeUpper = "Uint32"
+		F.Word.Add = "Add32"
+		F.Word.Sub = "Sub32"
+		F.Word.Mul = "Mul32"
+		F.Word.Len = "Len32"
+	}
+
+	F.NbBytes = F.NbWords * F.Word.ByteSize
+
 	// setting qInverse
+	radix := uint(64)
+	if F.Word.BitSize == 32 {
+		radix = 32
+	}
+
 	_r := big.NewInt(1)
-	_r.Lsh(_r, uint(F.NbWords)*64)
+	_r.Lsh(_r, uint(F.NbWords)*radix)
 	_rInv := big.NewInt(1)
 	_qInv := big.NewInt(0)
 	extendedEuclideanAlgo(_r, &bModulus, _rInv, _qInv)
@@ -140,24 +178,25 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool)
 
 	{
 		c := F.NbWords * 64
-		F.UsingP20Inverse = F.NbWords > 1 && F.NbBits < c
+		// TODO @gbotrel check inverse performance for 32 bits
+		F.UsingP20Inverse = F.NbWords > 1 && F.NbBits < c && F.Word.BitSize == 64
 	}
 
 	// rsquare
 	_rSquare := big.NewInt(2)
-	exponent := big.NewInt(int64(F.NbWords) * 64 * 2)
+	exponent := big.NewInt(int64(F.NbWords) * int64(radix) * 2)
 	_rSquare.Exp(_rSquare, exponent, &bModulus)
 	F.RSquare = toUint64Slice(_rSquare, F.NbWords)
 
 	var one big.Int
 	one.SetUint64(1)
-	one.Lsh(&one, uint(F.NbWords)*64).Mod(&one, &bModulus)
+	one.Lsh(&one, uint(F.NbWords)*radix).Mod(&one, &bModulus)
 	F.One = toUint64Slice(&one,
F.NbWords) { var n big.Int n.SetUint64(13) - n.Lsh(&n, uint(F.NbWords)*64).Mod(&n, &bModulus) + n.Lsh(&n, uint(F.NbWords)*radix).Mod(&n, &bModulus) F.Thirteen = toUint64Slice(&n, F.NbWords) } @@ -246,7 +285,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) var g big.Int g.Exp(&nonResidue, &s, &bModulus) // store g in montgomery form - g.Lsh(&g, uint(F.NbWords)*64).Mod(&g, &bModulus) + g.Lsh(&g, uint(F.NbWords)*radix).Mod(&g, &bModulus) F.SqrtG = toUint64Slice(&g, F.NbWords) // store non residue in montgomery form @@ -342,7 +381,7 @@ func (f *FieldConfig) StringToMont(str string) big.Int { func (f *FieldConfig) ToMont(nonMont big.Int) big.Int { var mont big.Int - mont.Lsh(&nonMont, uint(f.NbWords)*64) + mont.Lsh(&nonMont, uint(f.NbWords)*uint(f.Word.BitSize)) mont.Mod(&mont, f.ModulusBig) return mont } diff --git a/field/generator/internal/templates/element/arith.go b/field/generator/internal/templates/element/arith.go index 06a7805588..427529f2a8 100644 --- a/field/generator/internal/templates/element/arith.go +++ b/field/generator/internal/templates/element/arith.go @@ -6,42 +6,42 @@ import ( ) // madd0 hi = a*b + c (discards lo bits) -func madd0(a, b, c uint64) (hi uint64) { - var carry, lo uint64 - hi, lo = bits.Mul64(a, b) - _, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) +func madd0(a, b, c {{$.Word.TypeLower}}) (hi {{$.Word.TypeLower}}) { + var carry, lo {{$.Word.TypeLower}} + hi, lo = bits.{{$.Word.Mul}}(a, b) + _, carry = bits.{{$.Word.Add}}(lo, c, 0) + hi, _ = bits.{{$.Word.Add}}(hi, 0, carry) return } // madd1 hi, lo = a*b + c -func madd1(a, b, c uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) +func madd1(a, b, c {{$.Word.TypeLower}}) (hi {{$.Word.TypeLower}}, lo {{$.Word.TypeLower}}) { + var carry {{$.Word.TypeLower}} + hi, lo = bits.{{$.Word.Mul}}(a, b) + lo, carry = bits.{{$.Word.Add}}(lo, c, 0) + hi, _ = bits.{{$.Word.Add}}(hi, 0, carry) return } // madd2 hi, lo = a*b + c + d -func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - c, carry = bits.Add64(c, d, 0) - hi, _ = bits.Add64(hi, 0, carry) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) +func madd2(a, b, c, d {{$.Word.TypeLower}}) (hi {{$.Word.TypeLower}}, lo {{$.Word.TypeLower}}) { + var carry {{$.Word.TypeLower}} + hi, lo = bits.{{$.Word.Mul}}(a, b) + c, carry = bits.{{$.Word.Add}}(c, d, 0) + hi, _ = bits.{{$.Word.Add}}(hi, 0, carry) + lo, carry = bits.{{$.Word.Add}}(lo, c, 0) + hi, _ = bits.{{$.Word.Add}}(hi, 0, carry) return } -func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - c, carry = bits.Add64(c, d, 0) - hi, _ = bits.Add64(hi, 0, carry) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, e, carry) +func madd3(a, b, c, d, e {{$.Word.TypeLower}}) (hi {{$.Word.TypeLower}}, lo {{$.Word.TypeLower}}) { + var carry {{$.Word.TypeLower}} + hi, lo = bits.{{$.Word.Mul}}(a, b) + c, carry = bits.{{$.Word.Add}}(c, d, 0) + hi, _ = bits.{{$.Word.Add}}(hi, 0, carry) + lo, carry = bits.{{$.Word.Add}}(lo, c, 0) + hi, _ = bits.{{$.Word.Add}}(hi, e, carry) return } diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index dd05519baf..22c15e9af8 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -18,7 +18,7 @@ 
import ( "github.com/bits-and-blooms/bitset" ) -// {{.ElementName}} represents a field element stored on {{.NbWords}} words (uint64) +// {{.ElementName}} represents a field element stored on {{.NbWords}} words ({{$.Word.TypeLower}}) // // {{.ElementName}} are assumed to be in Montgomery form in all methods. // @@ -30,10 +30,10 @@ import ( // Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. -type {{.ElementName}} [{{.NbWords}}]uint64 +type {{.ElementName}} [{{.NbWords}}]{{$.Word.TypeLower}} const ( - Limbs = {{.NbWords}} // number of 64 bits words needed to represent a {{.ElementName}} + Limbs = {{.NbWords}} // number of {{$.Word.TypeLower}} words needed to represent a {{.ElementName}} Bits = {{.NbBits}} // number of bits needed to represent a {{.ElementName}} Bytes = {{.NbBytes}} // number of bytes needed to represent a {{.ElementName}} ) @@ -42,9 +42,9 @@ const ( // Field modulus q const ( {{- range $i := $.NbWordsIndexesFull}} - q{{$i}} uint64 = {{index $.Q $i}} + q{{$i}} {{$.Word.TypeLower}} = {{index $.Q $i}} {{- if eq $.NbWords 1}} - q uint64 = q0 + q {{$.Word.TypeLower}} = q0 {{- end}} {{- end}} ) @@ -66,7 +66,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = {{index .QInverse 0}} +const qInvNeg {{$.Word.TypeLower}} = {{index .QInverse 0}} {{- if eq .NbWords 4}} // mu = 2^288 / q needed for partial Barrett reduction @@ -84,16 +84,29 @@ func init() { // var v {{.ElementName}} // v.SetUint64(...) func New{{.ElementName}}(v uint64) {{.ElementName}} { - z := {{.ElementName}}{v} +{{- if eq .Word.BitSize 32}} + z := {{.ElementName}}{ {{$.Word.TypeLower}}(v % uint64(q0)) } z.Mul(&z, &rSquare) return z +{{- else }} + z := {{.ElementName}}{ {{$.Word.TypeLower}}(v) } + z.Mul(&z, &rSquare) + return z +{{- end}} } // SetUint64 sets z to v and returns z func (z *{{.ElementName}}) SetUint64(v uint64) *{{.ElementName}} { +{{- if eq .Word.BitSize 32}} // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form - *z = {{.ElementName}}{v} + *z = {{.ElementName}}{ {{$.Word.TypeLower}}(v % uint64(q0)) } return z.Mul(z, &rSquare) // z.toMont() +{{- else }} + // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form + *z = {{.ElementName}}{ {{$.Word.TypeLower}}(v) } + return z.Mul(z, &rSquare) // z.toMont() +{{- end}} + } // SetInt64 sets z to v and returns z @@ -211,7 +224,7 @@ func (z *{{.ElementName}}) Equal(x *{{.ElementName}}) bool { // NotEqual returns 0 if and only if z == x; constant-time func (z *{{.ElementName}}) NotEqual(x *{{.ElementName}}) uint64 { -return {{- range $i := reverse .NbWordsIndexesNoZero}}(z[{{$i}}] ^ x[{{$i}}]) | {{end}}(z[0] ^ x[0]) +return uint64({{- range $i := reverse .NbWordsIndexesNoZero}}(z[{{$i}}] ^ x[{{$i}}]) | {{end}}(z[0] ^ x[0])) } // IsZero returns z == 0 @@ -241,7 +254,7 @@ func (z *{{.ElementName}}) IsUint64() bool { // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. 
 func (z *{{.ElementName}}) Uint64() uint64 {
-	return z.Bits()[0]
+	return uint64(z.Bits()[0])
 }
 
 // FitsOnOneWord reports whether all of z's words, except the least significant one, are 0
@@ -283,10 +296,10 @@ func (z *{{.ElementName}}) LexicographicallyLargest() bool {
 
 	_z := z.Bits()
 
-	var b uint64
-	_, b = bits.Sub64(_z[0], {{index .QMinusOneHalvedP 0}}, 0)
+	var b {{$.Word.TypeLower}}
+	_, b = bits.{{$.Word.Sub}}(_z[0], {{index .QMinusOneHalvedP 0}}, 0)
 	{{- range $i := .NbWordsIndexesNoZero}}
-	_, b = bits.Sub64(_z[{{$i}}], {{index $.QMinusOneHalvedP $i}}, b)
+	_, b = bits.{{$.Word.Sub}}(_z[{{$i}}], {{index $.QMinusOneHalvedP $i}}, b)
 	{{- end}}
 
 	return b == 0
@@ -323,13 +336,16 @@ func (z *{{.ElementName}}) SetRandom() (*{{.ElementName}}, error) {
 			return nil, err
 		}
 
+		{{- if eq $.Word.BitSize 32}}
+		// TODO @gbotrel check this is correct with 32bits words
+		{{- end}}
 		// Clear unused bits in the most significant byte to increase probability
 		// that the candidate is < q.
 		bytes[k-1] &= uint8(int(1<<b) - 1)

 	{{- if eq .NbWords 1}}
-	if x[0]&(1<<63) == (1 << 63) {
-		// if highest bit is set, then we have a carry to x + x, we shift and subtract q
-		z[0] = (x[0] << 1) - q
-	} else {
-		// highest bit is not set, but x + x can still be >= q
-		z[0] = (x[0] << 1)
-		if z[0] >= q {
-			z[0] -= q
-		}
+	z[0] = x[0] << 1
+	if z[0] >= q {
+		z[0] -= q
 	}
 {{- else}}
 	{{ $hasCarry := or (not $.NoCarry) (gt $.NbWords 1)}}
@@ -478,10 +488,10 @@ func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} {
 // Sub z = x - y (mod q)
 func (z *{{.ElementName}}) Sub( x, y *{{.ElementName}}) *{{.ElementName}} {
-	var b uint64
-	z[0], b = bits.Sub64(x[0], y[0], 0)
+	var b {{$.Word.TypeLower}}
+	z[0], b = bits.{{$.Word.Sub}}(x[0], y[0], 0)
 	{{- range $i := .NbWordsIndexesNoZero}}
-	z[{{$i}}], b = bits.Sub64(x[{{$i}}], y[{{$i}}], b)
+	z[{{$i}}], b = bits.{{$.Word.Sub}}(x[{{$i}}], y[{{$i}}], b)
 	{{- end}}
 	if b != 0 {
 		{{- if eq .NbWords 1}}
@@ -528,7 +538,7 @@ func (z *{{.ElementName}}) Neg( x *{{.ElementName}}) *{{.ElementName}} {
 // Select is a constant-time conditional move.
 // If c=0, z = x0. Else z = x1
 func (z *{{.ElementName}}) Select(c int, x0 *{{.ElementName}}, x1 *{{.ElementName}}) *{{.ElementName}} {
-	cC := uint64( (int64(c) | -int64(c)) >> 63 )	// "canonicized" into: 0 if c=0, -1 otherwise
+	cC := {{$.Word.TypeLower}}( (int64(c) | -int64(c)) >> 63 )	// "canonicized" into: 0 if c=0, -1 otherwise
 	{{- range $i := .NbWordsIndexesFull }}
 	z[{{$i}}] = x0[{{$i}}] ^ cC & (x0[{{$i}}] ^ x1[{{$i}}])
 	{{- end}}
@@ -612,10 +622,10 @@ func _butterflyGeneric(a, b *{{.ElementName}}) {
 func (z *{{.ElementName}}) BitLen() int {
 	{{- range $i := reverse .NbWordsIndexesNoZero}}
 	if z[{{$i}}] != 0 {
-		return {{mul $i 64}} + bits.Len64(z[{{$i}}])
+		return {{mul $i 64}} + bits.{{$.Word.Len}}(z[{{$i}}])
 	}
 	{{- end}}
-	return bits.Len64(z[0])
+	return bits.{{$.Word.Len}}(z[0])
 }
 
 // Hash msg to count prime field elements.
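With Word plugged into the templates, the one-word Mul (see the mul_cios.go diff further below) should expand to the same REDC shape as koalabear's element_purego.go, just on 32-bit limbs. A minimal sketch of the expected expansion — illustrative only, not generator output; it assumes the package-level uint32 constants q and qInvNeg and math/bits:

	func mul32(x, y uint32) uint32 {
		hi, lo := bits.Mul32(x, y)
		if lo != 0 {
			hi++ // (lo + m*q)/R contributes exactly 1 when lo != 0
		}
		m := lo * qInvNeg          // m = lo * (-q⁻¹) mod 2³²
		hi2, _ := bits.Mul32(m, q) // only the high half of m*q is needed
		r, carry := bits.Add32(hi2, hi, 0)
		if carry != 0 || r >= q {
			r -= q
		}
		return r
	}

A quick sanity check of the (lo != 0) trick with toy parameters R = 2⁸, q = 251, qInvNeg = 205: for x·y = 21, m = 21·205 mod 256 = 209 and m·q = 52459 = 204·256 + 235, so r = 0 + 204 + 1 = 205; indeed 205·256 = 21 + 52459.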
diff --git a/field/generator/internal/templates/element/conv.go b/field/generator/internal/templates/element/conv.go index 9c40ffec79..0557c0d3f1 100644 --- a/field/generator/internal/templates/element/conv.go +++ b/field/generator/internal/templates/element/conv.go @@ -26,11 +26,11 @@ func (z *{{.ElementName}}) String() string { func (z *{{.ElementName}}) toBigInt(res *big.Int) *big.Int { var b [Bytes]byte {{- range $i := reverse .NbWordsIndexesFull}} - {{- $j := mul $i 8}} + {{- $j := mul $i $.Word.ByteSize}} {{- $k := sub $.NbWords 1}} {{- $k := sub $k $i}} - {{- $jj := add $j 8}} - binary.BigEndian.PutUint64(b[{{$j}}:{{$jj}}], z[{{$k}}]) + {{- $jj := add $j $.Word.ByteSize}} + binary.BigEndian.Put{{$.Word.TypeUpper}}(b[{{$j}}:{{$jj}}], z[{{$k}}]) {{- end}} return res.SetBytes(b[:]) @@ -61,12 +61,12 @@ func (z *{{.ElementName}}) Text(base int) string { zzNeg.Neg(z) zzNeg.fromMont() if zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { - return "-" + strconv.FormatUint(zzNeg[0], base) + return "-" + strconv.FormatUint(uint64(zzNeg[0]), base) } } {{- end}} zz := z.Bits() - return strconv.FormatUint(zz[0], base) + return strconv.FormatUint(uint64(zz[0]), base) {{- else }} if base == 10 { var zzNeg {{.ElementName}} @@ -103,10 +103,10 @@ func (z {{.ElementName}}) ToBigIntRegular(res *big.Int) *big.Int { return z.toBigInt(res) } -// Bits provides access to z by returning its value as a little-endian [{{.NbWords}}]uint64 array. +// Bits provides access to z by returning its value as a little-endian [{{.NbWords}}]{{.Word.TypeLower}} array. // Bits is intended to support implementation of missing low-level {{.ElementName}} // functionality outside this package; it should be avoided otherwise. -func (z *{{.ElementName}}) Bits() [{{.NbWords}}]uint64 { +func (z *{{.ElementName}}) Bits() [{{.NbWords}}]{{.Word.TypeLower}} { _z := *z fromMont(&_z) return _z @@ -207,14 +207,14 @@ func (z *{{.ElementName}}) setBigInt(v *big.Int) *{{.ElementName}} { if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { - z[i] = uint64(vBits[i]) + z[i] = {{$.Word.TypeLower}}(vBits[i]) } } else { for i := 0; i < len(vBits); i++ { if i%2 == 0 { - z[i/2] = uint64(vBits[i]) + z[i/2] = {{$.Word.TypeLower}}(vBits[i]) } else { - z[i/2] |= uint64(vBits[i]) << 32 + z[i/2] |= {{$.Word.TypeLower}}(vBits[i]) << 32 } } } @@ -323,11 +323,11 @@ type bigEndian struct{} func (bigEndian) Element(b *[Bytes]byte) ({{.ElementName}}, error) { var z {{.ElementName}} {{- range $i := reverse .NbWordsIndexesFull}} - {{- $j := mul $i 8}} + {{- $j := mul $i $.Word.ByteSize}} {{- $k := sub $.NbWords 1}} {{- $k := sub $k $i}} - {{- $jj := add $j 8}} - z[{{$k}}] = binary.BigEndian.Uint64((*b)[{{$j}}:{{$jj}}]) + {{- $jj := add $j $.Word.ByteSize}} + z[{{$k}}] = binary.BigEndian.{{$.Word.TypeUpper}}((*b)[{{$j}}:{{$jj}}]) {{- end}} if !z.smallerThanModulus() { @@ -342,11 +342,11 @@ func (bigEndian) PutElement(b *[Bytes]byte, e {{.ElementName}}) { e.fromMont() {{- range $i := reverse .NbWordsIndexesFull}} - {{- $j := mul $i 8}} + {{- $j := mul $i $.Word.ByteSize}} {{- $k := sub $.NbWords 1}} {{- $k := sub $k $i}} - {{- $jj := add $j 8}} - binary.BigEndian.PutUint64((*b)[{{$j}}:{{$jj}}], e[{{$k}}]) + {{- $jj := add $j $.Word.ByteSize}} + binary.BigEndian.Put{{$.Word.TypeUpper}}((*b)[{{$j}}:{{$jj}}], e[{{$k}}]) {{- end}} } @@ -362,9 +362,9 @@ type littleEndian struct{} func (littleEndian) Element(b *[Bytes]byte) ({{.ElementName}}, error) { var z {{.ElementName}} {{- range $i := .NbWordsIndexesFull}} - {{- $j := mul $i 8}} - {{- $jj := add $j 8}} - z[{{$i}}] = 
binary.LittleEndian.Uint64((*b)[{{$j}}:{{$jj}}]) + {{- $j := mul $i $.Word.ByteSize}} + {{- $jj := add $j $.Word.ByteSize}} + z[{{$i}}] = binary.LittleEndian.{{$.Word.TypeUpper}}((*b)[{{$j}}:{{$jj}}]) {{- end}} if !z.smallerThanModulus() { @@ -379,9 +379,9 @@ func (littleEndian) PutElement(b *[Bytes]byte, e {{.ElementName}}) { e.fromMont() {{- range $i := .NbWordsIndexesFull}} - {{- $j := mul $i 8}} - {{- $jj := add $j 8}} - binary.LittleEndian.PutUint64((*b)[{{$j}}:{{$jj}}], e[{{$i}}]) + {{- $j := mul $i $.Word.ByteSize}} + {{- $jj := add $j $.Word.ByteSize}} + binary.LittleEndian.Put{{$.Word.TypeUpper}}((*b)[{{$j}}:{{$jj}}], e[{{$i}}]) {{- end}} } diff --git a/field/generator/internal/templates/element/inverse.go b/field/generator/internal/templates/element/inverse.go index 6adaedc557..a9cd2ff13d 100644 --- a/field/generator/internal/templates/element/inverse.go +++ b/field/generator/internal/templates/element/inverse.go @@ -20,7 +20,7 @@ if b != 0 { {{/* We use big.Int for Inverse for these type of moduli */}} {{if not $.UsingP20Inverse}} -{{- if eq .NbWords 1}} +{{- if and (eq .NbWords 1) (eq .Word.BitSize 64)}} // Inverse z = x⁻¹ (mod q) // // if x == 0, sets and returns z = x diff --git a/field/generator/internal/templates/element/mul_cios.go b/field/generator/internal/templates/element/mul_cios.go index 6fb554c3bd..7d19d86493 100644 --- a/field/generator/internal/templates/element/mul_cios.go +++ b/field/generator/internal/templates/element/mul_cios.go @@ -39,15 +39,15 @@ package element // The same (arm64) unrolled Go code produce satisfying performance for WASM (compiled using TinyGo). const MulCIOS = ` {{ define "mul_cios" }} - var t [{{add .all.NbWords 1}}]uint64 - var D uint64 - var m, C uint64 + var t [{{add .all.NbWords 1}}]{{$.all.Word.TypeLower}} + var D {{$.all.Word.TypeLower}} + var m, C {{$.all.Word.TypeLower}} {{- range $j := .all.NbWordsIndexesFull}} // ----------------------------------- // First loop {{ if eq $j 0}} - C, t[0] = bits.Mul64({{$.V2}}[{{$j}}], {{$.V1}}[0]) + C, t[0] = bits.{{$.all.Word.Mul}}({{$.V2}}[{{$j}}], {{$.V1}}[0]) {{- range $i := $.all.NbWordsIndexesNoZero}} C, t[{{$i}}] = madd1({{$.V2}}[{{$j}}], {{$.V1}}[{{$i}}], C) {{- end}} @@ -57,7 +57,7 @@ const MulCIOS = ` C, t[{{$i}}] = madd2({{$.V2}}[{{$j}}], {{$.V1}}[{{$i}}], t[{{$i}}], C) {{- end}} {{ end }} - t[{{$.all.NbWords}}], D = bits.Add64(t[{{$.all.NbWords}}], C, 0) + t[{{$.all.NbWords}}], D = bits.{{$.all.Word.Add}}(t[{{$.all.NbWords}}], C, 0) // m = t[0]n'[0] mod W m = t[0] * qInvNeg @@ -69,22 +69,22 @@ const MulCIOS = ` C, t[{{sub $i 1}}] = madd2(m, q{{$i}}, t[{{$i}}], C) {{- end}} - t[{{sub $.all.NbWords 1}}], C = bits.Add64(t[{{$.all.NbWords}}], C, 0) - t[{{$.all.NbWords}}], _ = bits.Add64(0, D, C) + t[{{sub $.all.NbWords 1}}], C = bits.{{$.all.Word.Add}}(t[{{$.all.NbWords}}], C, 0) + t[{{$.all.NbWords}}], _ = bits.{{$.all.Word.Add}}(0, D, C) {{- end}} if t[{{$.all.NbWords}}] != 0 { // we need to reduce, we have a result on {{add 1 $.all.NbWords}} words {{- if gt $.all.NbWords 1}} - var b uint64 + var b {{$.all.Word.TypeLower}} {{- end}} - z[0], {{- if gt $.all.NbWords 1}}b{{- else}}_{{- end}} = bits.Sub64(t[0], q0, 0) + z[0], {{- if gt $.all.NbWords 1}}b{{- else}}_{{- end}} = bits.{{$.all.Word.Sub}}(t[0], q0, 0) {{- range $i := .all.NbWordsIndexesNoZero}} {{- if eq $i $.all.NbWordsLastIndex}} - z[{{$i}}], _ = bits.Sub64(t[{{$i}}], q{{$i}}, b) + z[{{$i}}], _ = bits.{{$.all.Word.Sub}}(t[{{$i}}], q{{$i}}, b) {{- else }} - z[{{$i}}], b = bits.Sub64(t[{{$i}}], q{{$i}}, b) + z[{{$i}}], b = 
bits.{{$.all.Word.Sub}}(t[{{$i}}], q{{$i}}, b) {{- end}} {{- end}} return {{if $.ReturnZ }} z{{- end}} @@ -110,14 +110,14 @@ const MulCIOS = ` // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. - var r uint64 - hi, lo := bits.Mul64({{$.V1}}[0], {{$.V2}}[0]) + var r {{$.all.Word.TypeLower}} + hi, lo := bits.{{$.all.Word.Mul}}({{$.V1}}[0], {{$.V2}}[0]) if lo != 0 { hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow } m := lo * qInvNeg - hi2, _ := bits.Mul64(m, q) - r, carry := bits.Add64(hi2, hi, 0) + hi2, _ := bits.{{$.all.Word.Mul}}(m, q) + r, carry := bits.{{$.all.Word.Add}}(hi2, hi, 0) if carry != 0 || r >= q { // we need to reduce diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index 35d777e493..6d2baef32b 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -1576,9 +1576,9 @@ func gen() gopter.Gen { g.element = {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - genParams.NextUint64(),{{end}} + {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} } - if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { g.element[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) } @@ -1586,9 +1586,9 @@ func gen() gopter.Gen { for !g.element.smallerThanModulus() { g.element = {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - genParams.NextUint64(),{{end}} + {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} } - if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { g.element[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) } } @@ -1604,19 +1604,19 @@ func genRandomFq(genParams *gopter.GenParameters) {{.ElementName}} { g = {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - genParams.NextUint64(),{{end}} + {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} } - if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { g[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) } for !g.smallerThanModulus() { g = {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - genParams.NextUint64(),{{end}} + {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} } - if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { g[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) } } @@ -1629,12 +1629,12 @@ func genFull() gopter.Gen { return func(genParams *gopter.GenParameters) *gopter.GenResult { a := genRandomFq(genParams) - var carry uint64 + var carry {{$.Word.TypeLower}} {{- range $i := .NbWordsIndexesFull}} {{- if eq $i $.NbWordsLastIndex}} - a[{{$i}}], _ = bits.Add64(a[{{$i}}], qElement[{{$i}}], carry) + a[{{$i}}], _ = bits.{{$.Word.Add}}(a[{{$i}}], qElement[{{$i}}], carry) {{- else}} - a[{{$i}}], carry = bits.Add64(a[{{$i}}], qElement[{{$i}}], carry) + a[{{$i}}], carry = bits.{{$.Word.Add}}(a[{{$i}}], qElement[{{$i}}], carry) {{- end}} {{- end}} diff --git a/field/generator/internal/templates/element/tests_vector.go b/field/generator/internal/templates/element/tests_vector.go index 41187a2caa..9bdadd353f 100644 --- 
a/field/generator/internal/templates/element/tests_vector.go +++ b/field/generator/internal/templates/element/tests_vector.go @@ -330,9 +330,9 @@ func genVector(size int) gopter.Gen { g := make(Vector, size) mixer := {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - genParams.NextUint64(),{{end}} + {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} } - if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { mixer[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) } @@ -340,9 +340,9 @@ func genVector(size int) gopter.Gen { for !mixer.smallerThanModulus() { mixer = {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - genParams.NextUint64(),{{end}} + {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} } - if qElement[{{.NbWordsLastIndex}}] != ^uint64(0) { + if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { mixer[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) } } diff --git a/field/generator/internal/templates/element/vector.go b/field/generator/internal/templates/element/vector.go index 447b429556..f3120eb73a 100644 --- a/field/generator/internal/templates/element/vector.go +++ b/field/generator/internal/templates/element/vector.go @@ -107,11 +107,11 @@ func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { bend := bstart + Bytes b := bSlice[bstart:bend] {{- range $i := reverse .NbWordsIndexesFull}} - {{- $j := mul $i 8}} + {{- $j := mul $i $.Word.ByteSize}} {{- $k := sub $.NbWords 1}} {{- $k := sub $k $i}} - {{- $jj := add $j 8}} - z[{{$k}}] = binary.BigEndian.Uint64(b[{{$j}}:{{$jj}}]) + {{- $jj := add $j $.Word.ByteSize}} + z[{{$k}}] = binary.BigEndian.{{$.Word.TypeUpper}}(b[{{$j}}:{{$jj}}]) {{- end}} if !z.smallerThanModulus() { diff --git a/field/koalabear/arith.go b/field/koalabear/arith.go index 0252a5bd1e..aab49151f6 100644 --- a/field/koalabear/arith.go +++ b/field/koalabear/arith.go @@ -21,40 +21,40 @@ import ( ) // madd0 hi = a*b + c (discards lo bits) -func madd0(a, b, c uint64) (hi uint64) { - var carry, lo uint64 - hi, lo = bits.Mul64(a, b) - _, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) +func madd0(a, b, c uint32) (hi uint32) { + var carry, lo uint32 + hi, lo = bits.Mul32(a, b) + _, carry = bits.Add32(lo, c, 0) + hi, _ = bits.Add32(hi, 0, carry) return } // madd1 hi, lo = a*b + c -func madd1(a, b, c uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) +func madd1(a, b, c uint32) (hi uint32, lo uint32) { + var carry uint32 + hi, lo = bits.Mul32(a, b) + lo, carry = bits.Add32(lo, c, 0) + hi, _ = bits.Add32(hi, 0, carry) return } // madd2 hi, lo = a*b + c + d -func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - c, carry = bits.Add64(c, d, 0) - hi, _ = bits.Add64(hi, 0, carry) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) +func madd2(a, b, c, d uint32) (hi uint32, lo uint32) { + var carry uint32 + hi, lo = bits.Mul32(a, b) + c, carry = bits.Add32(c, d, 0) + hi, _ = bits.Add32(hi, 0, carry) + lo, carry = bits.Add32(lo, c, 0) + hi, _ = bits.Add32(hi, 0, carry) return } -func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - c, carry = bits.Add64(c, d, 0) - hi, _ = bits.Add64(hi, 0, carry) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, e, carry) +func madd3(a, b, c, d, e uint32) (hi 
uint32, lo uint32) { + var carry uint32 + hi, lo = bits.Mul32(a, b) + c, carry = bits.Add32(c, d, 0) + hi, _ = bits.Add32(hi, 0, carry) + lo, carry = bits.Add32(lo, c, 0) + hi, _ = bits.Add32(hi, e, carry) return } diff --git a/field/koalabear/element.go b/field/koalabear/element.go index fca7f64f55..7e42842839 100644 --- a/field/koalabear/element.go +++ b/field/koalabear/element.go @@ -19,22 +19,20 @@ package koalabear import ( "crypto/rand" "encoding/binary" + "errors" "io" "math/big" "math/bits" + "reflect" + "strconv" + "strings" + "github.com/bits-and-blooms/bitset" + "github.com/consensys/gnark-crypto/field/hash" "github.com/consensys/gnark-crypto/field/pool" ) -// rSquare where r is the Montgommery constant -// see section 2.3.2 of Tolga Acar's thesis -// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf -var rSquare = Element{ - 1111325836, -} - - -// Element represents a field element stored on 1 words (uint64) +// Element represents a field element stored on 1 words (uint32) // // Element are assumed to be in Montgomery form in all methods. // @@ -46,18 +44,18 @@ var rSquare = Element{ // # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. -type Element [1]uint64 +type Element [1]uint32 const ( - Limbs = 1 // number of 64 bits words needed to represent a Element + Limbs = 1 // number of uint32 words needed to represent a Element Bits = 31 // number of bits needed to represent a Element - Bytes = 8 // number of bytes needed to represent a Element + Bytes = 4 // number of bytes needed to represent a Element ) // Field modulus q const ( - q0 uint64 = 2130706433 - q uint64 = q0 + q0 uint32 = 2130706433 + q uint32 = q0 ) var qElement = Element{ @@ -76,7 +74,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 13906834176474087423 +const qInvNeg uint32 = 2130706431 func init() { _modulus.SetString("7f000001", 16) @@ -89,7 +87,7 @@ func init() { // var v Element // v.SetUint64(...) 
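+//
+// Note that v is first reduced modulo q and only then converted to Montgomery
+// form, so values ⩾ q wrap around. Illustrative sketch (hypothetical usage,
+// not part of this patch):
+//
+//	a := NewElement(5)
+//	b := NewElement(uint64(q0) + 5)
+//	fmt.Println(a.Equal(&b)) // true: both encode 5 mod q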
func NewElement(v uint64) Element {
-	z := Element{v}
+	z := Element{uint32(v % uint64(q0))}
 	z.Mul(&z, &rSquare)
 	return z
 }
@@ -97,8 +95,9 @@ func NewElement(v uint64) Element {
 // SetUint64 sets z to v and returns z
 func (z *Element) SetUint64(v uint64) *Element {
 	// sets z LSB to v (non-Montgomery form) and convert z to Montgomery form
-	*z = Element{v}
+	*z = Element{uint32(v % uint64(q0))}
 	return z.Mul(z, &rSquare) // z.toMont()
+
 }
 
 // SetInt64 sets z to v and returns z
@@ -122,6 +121,66 @@ func (z *Element) Set(x *Element) *Element {
 	return z
 }
 
+// SetInterface converts provided interface into Element
+// returns an error if provided type is not supported
+// supported types:
+//
+//	Element
+//	*Element
+//	uint64
+//	int
+//	string (see SetString for valid formats)
+//	*big.Int
+//	big.Int
+//	[]byte
+func (z *Element) SetInterface(i1 interface{}) (*Element, error) {
+	if i1 == nil {
+		return nil, errors.New("can't set koalabear.Element with <nil>")
+	}
+
+	switch c1 := i1.(type) {
+	case Element:
+		return z.Set(&c1), nil
+	case *Element:
+		if c1 == nil {
+			return nil, errors.New("can't set koalabear.Element with <nil>")
+		}
+		return z.Set(c1), nil
+	case uint8:
+		return z.SetUint64(uint64(c1)), nil
+	case uint16:
+		return z.SetUint64(uint64(c1)), nil
+	case uint32:
+		return z.SetUint64(uint64(c1)), nil
+	case uint:
+		return z.SetUint64(uint64(c1)), nil
+	case uint64:
+		return z.SetUint64(c1), nil
+	case int8:
+		return z.SetInt64(int64(c1)), nil
+	case int16:
+		return z.SetInt64(int64(c1)), nil
+	case int32:
+		return z.SetInt64(int64(c1)), nil
+	case int64:
+		return z.SetInt64(c1), nil
+	case int:
+		return z.SetInt64(int64(c1)), nil
+	case string:
+		return z.SetString(c1)
+	case *big.Int:
+		if c1 == nil {
+			return nil, errors.New("can't set koalabear.Element with <nil>")
+		}
+		return z.SetBigInt(c1), nil
+	case big.Int:
+		return z.SetBigInt(&c1), nil
+	case []byte:
+		return z.SetBytes(c1), nil
+	default:
+		return nil, errors.New("can't set koalabear.Element from type " + reflect.TypeOf(i1).String())
+	}
+}
 
 // SetZero z = 0
 func (z *Element) SetZero() *Element {
@@ -131,10 +190,17 @@
 // SetOne z = 1 (in Montgomery form)
 func (z *Element) SetOne() *Element {
-	z[0] = 402124772
+	z[0] = 33554430
 	return z
 }
 
+// Div z = x*y⁻¹ (mod q)
+func (z *Element) Div(x, y *Element) *Element {
+	var yInv Element
+	yInv.Inverse(y)
+	z.Mul(x, &yInv)
+	return z
+}
 
 // Equal returns z == x; constant-time
 func (z *Element) Equal(x *Element) bool {
@@ -143,7 +209,7 @@
 // NotEqual returns 0 if and only if z == x; constant-time
 func (z *Element) NotEqual(x *Element) uint64 {
-	return (z[0] ^ x[0])
+	return uint64((z[0] ^ x[0]))
 }
 
 // IsZero returns z == 0
@@ -153,9 +219,56 @@
 // IsOne returns z == 1
 func (z *Element) IsOne() bool {
-	return z[0] == 402124772
+	return z[0] == 33554430
+}
+
+// IsUint64 reports whether z can be represented as an uint64.
+func (z *Element) IsUint64() bool {
+	return true
+}
+
+// Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined.
+func (z *Element) Uint64() uint64 {
+	return uint64(z.Bits()[0])
+}
+
+// FitsOnOneWord reports whether z words (except the least significant word) are 0
+//
+// It is the responsibility of the caller to convert from Montgomery to Regular form if needed.
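+//
+// For this single-limb field it always returns true. Illustrative sketch
+// (hypothetical usage, not part of this patch):
+//
+//	var z Element
+//	z.SetUint64(42)
+//	if z.FitsOnOneWord() {
+//		_ = z.Uint64() // safe: the whole value fits in the low word
+//	}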
+func (z *Element) FitsOnOneWord() bool { + return true +} + +// Cmp compares (lexicographic order) z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +func (z *Element) Cmp(x *Element) int { + _z := z.Bits() + _x := x.Bits() + if _z[0] > _x[0] { + return 1 + } else if _z[0] < _x[0] { + return -1 + } + return 0 } +// LexicographicallyLargest returns true if this element is strictly lexicographically +// larger than its negation, false otherwise +func (z *Element) LexicographicallyLargest() bool { + // adapted from github.com/zkcrypto/bls12_381 + // we check if the element is larger than (q-1) / 2 + // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 + + _z := z.Bits() + + var b uint32 + _, b = bits.Sub32(_z[0], 1065353217, 0) + + return b == 0 +} // SetRandom sets z to a uniform random value in [0, q). // @@ -187,11 +300,11 @@ func (z *Element) SetRandom() (*Element, error) { if _, err := io.ReadFull(rand.Reader, bytes[:k]); err != nil { return nil, err } - + // TODO @gbotrel check this is correct with 32bits words // Clear unused bits in in the most significant byte to increase probability // that the candidate is < q. bytes[k-1] &= uint8(int(1<> 1 @@ -237,7 +350,7 @@ func (z *Element) fromMont() *Element { // Add z = x + y (mod q) func (z *Element) Add(x, y *Element) *Element { - z[0], _ = bits.Add64(x[0], y[0], 0) + z[0], _ = bits.Add32(x[0], y[0], 0) if z[0] >= q { z[0] -= q } @@ -246,23 +359,17 @@ func (z *Element) Add(x, y *Element) *Element { // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { - if x[0]&(1<<63) == (1 << 63) { - // if highest bit is set, then we have a carry to x + x, we shift and subtract q - z[0] = (x[0] << 1) - q - } else { - // highest bit is not set, but x + x can still be >= q - z[0] = (x[0] << 1) - if z[0] >= q { - z[0] -= q - } + z[0] = x[0] << 1 + if z[0] >= q { + z[0] -= q } return z } // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) + var b uint32 + z[0], b = bits.Sub32(x[0], y[0], 0) if b != 0 { z[0] += q } @@ -279,6 +386,56 @@ func (z *Element) Neg(x *Element) *Element { return z } +// Select is a constant-time conditional move. +// If c=0, z = x0. Else z = x1 +func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { + cC := uint32((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise + z[0] = x0[0] ^ cC&(x0[0]^x1[0]) + return z +} + +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. +func _mulGeneric(z, x, y *Element) { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t [2]uint32 + var D uint32 + var m, C uint32 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul32(y[0], x[0]) + + t[1], D = bits.Add32(t[1], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + + t[0], C = bits.Add32(t[1], C, 0) + t[1], _ = bits.Add32(0, D, C) + + if t[1] != 0 { + // we need to reduce, we have a result on 2 words + z[0], _ = bits.Sub32(t[0], q0, 0) + return + } + + // copy t into z + z[0] = t[0] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } +} func _fromMontGeneric(z *Element) { // the following lines implement z = z * 1 @@ -305,6 +462,117 @@ func _reduceGeneric(z *Element) { } } +// BatchInvert returns a new slice with every element inverted. +// Uses Montgomery batch inversion trick +func BatchInvert(a []Element) []Element { + res := make([]Element, len(a)) + if len(a) == 0 { + return res + } + + zeroes := bitset.New(uint(len(a))) + accumulator := One() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes.Set(uint(i)) + continue + } + res[i] = accumulator + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes.Test(uint(i)) { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +func _butterflyGeneric(a, b *Element) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// BitLen returns the minimum number of bits needed to represent z +// returns 0 if z == 0 +func (z *Element) BitLen() int { + return bits.Len32(z[0]) +} + +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := hash.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + pool.BigInt.Put(vv) + + return res, nil +} + +// Exp z = xᵏ (mod q) +func (z *Element) Exp(x Element, k *big.Int) *Element { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q) == (x⁻¹)ᵏ (mod q) + x.Inverse(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = pool.BigInt.Get() + defer pool.BigInt.Put(e) + e.Neg(k) + } + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z +} + +// rSquare where r is the Montgommery constant +// see section 2.3.2 of Tolga Acar's thesis +// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf +var rSquare = Element{ + 402124772, +} // toMont converts z to Montgomery form // sets and returns z = z * r² @@ -312,15 +580,47 @@ func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } +// String returns the decimal representation of z as generated by +// z.Text(10). 
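+//
+// Illustrative sketch (hypothetical usage, not part of this patch):
+//
+//	var z Element
+//	z.SetUint64(42)
+//	fmt.Println(z.String()) // "42": the canonical value, not the Montgomery form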
+func (z *Element) String() string { + return z.Text(10) +} // toBigInt returns z as a big.Int in Montgomery form func (z *Element) toBigInt(res *big.Int) *big.Int { var b [Bytes]byte - binary.BigEndian.PutUint64(b[0:8], z[0]) + binary.BigEndian.PutUint32(b[0:4], z[0]) return res.SetBytes(b[:]) } +// Text returns the string representation of z in the given base. +// Base must be between 2 and 36, inclusive. The result uses the +// lower-case letters 'a' to 'z' for digit values 10 to 35. +// No prefix (such as "0x") is added to the string. If z is a nil +// pointer it returns "". +// If base == 10 and -z fits in a uint16 prefix "-" is added to the string. +func (z *Element) Text(base int) string { + if base < 2 || base > 36 { + panic("invalid base") + } + if z == nil { + return "" + } + + const maxUint16 = 65535 + if base == 10 { + var zzNeg Element + zzNeg.Neg(z) + zzNeg.fromMont() + if zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { + return "-" + strconv.FormatUint(uint64(zzNeg[0]), base) + } + } + zz := z.Bits() + return strconv.FormatUint(uint64(zz[0]), base) +} + // BigInt sets and return z as a *big.Int func (z *Element) BigInt(res *big.Int) *big.Int { _z := *z @@ -336,6 +636,72 @@ func (z Element) ToBigIntRegular(res *big.Int) *big.Int { return z.toBigInt(res) } +// Bits provides access to z by returning its value as a little-endian [1]uint32 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [1]uint32 { + _z := *z + fromMont(&_z) + return _z +} + +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) + return +} + +// Marshal returns the value of z as a big-endian byte slice +func (z *Element) Marshal() []byte { + b := z.Bytes() + return b[:] +} + +// Unmarshal is an alias for SetBytes, it sets z to the value of e. +func (z *Element) Unmarshal(e []byte) { + z.SetBytes(e) +} + +// SetBytes interprets e as the bytes of a big-endian unsigned integer, +// sets z to that value, and returns z. +func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. + // get a big int from our pool + vv := pool.BigInt.Get() + vv.SetBytes(e) + + // set big int + z.SetBigInt(vv) + + // put temporary object back in pool + pool.BigInt.Put(vv) + + return z +} + +// SetBytesCanonical interprets e as the bytes of a big-endian 4-byte integer. +// If e is not a 4-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. 
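+//
+// Illustrative sketch (hypothetical usage, not part of this patch):
+//
+//	var z Element
+//	err := z.SetBytesCanonical([]byte{0x00, 0x00, 0x00, 0x2a}) // z = 42, err == nil
+//	err = z.SetBytesCanonical([]byte{0xff, 0xff, 0xff, 0xff})  // err != nil: value ⩾ q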
+func (z *Element) SetBytesCanonical(e []byte) error {
+	if len(e) != Bytes {
+		return errors.New("invalid koalabear.Element encoding")
+	}
+	v, err := BigEndian.Element((*[Bytes]byte)(e))
+	if err != nil {
+		return err
+	}
+	*z = v
+	return nil
+}
 
 // SetBigInt sets z to v and returns z
 func (z *Element) SetBigInt(v *big.Int) *Element {
@@ -373,17 +739,248 @@ func (z *Element) setBigInt(v *big.Int) *Element {
 
 	if bits.UintSize == 64 {
 		for i := 0; i < len(vBits); i++ {
-			z[i] = uint64(vBits[i])
+			z[i] = uint32(vBits[i])
 		}
 	} else {
 		for i := 0; i < len(vBits); i++ {
-			if i%2 == 0 {
-				z[i/2] = uint64(vBits[i])
-			} else {
-				z[i/2] |= uint64(vBits[i]) << 32
-			}
+			// on 32-bit platforms big.Word is 32 bits wide and maps 1:1 to a uint32
+			// limb; the old pairwise packing would shift a uint32 by 32 (always 0)
+			z[i] = uint32(vBits[i])
 		}
 	}
 
 	return z.toMont()
 }
+
+// SetString creates a big.Int with number and calls SetBigInt on z
+//
+// The number prefix determines the actual base: A prefix of
+// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8,
+// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10
+// and no prefix is accepted.
+//
+// For base 16, lower and upper case letters are considered the same:
+// The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15.
+//
+// An underscore character ”_” may appear between a base
+// prefix and an adjacent digit, and between successive digits; such
+// underscores do not change the value of the number.
+// Incorrect placement of underscores is reported as a panic if there
+// are no other errors.
+//
+// If the number is invalid this method leaves z unchanged and returns nil, error.
+func (z *Element) SetString(number string) (*Element, error) {
+	// get temporary big int from the pool
+	vv := pool.BigInt.Get()
+
+	if _, ok := vv.SetString(number, 0); !ok {
+		return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number)
+	}
+
+	z.SetBigInt(vv)
+
+	// release object into pool
+	pool.BigInt.Put(vv)
+
+	return z, nil
+}
+
+// MarshalJSON returns json encoding of z (z.Text(10))
+// If z == nil, returns null
+func (z *Element) MarshalJSON() ([]byte, error) {
+	if z == nil {
+		return []byte("null"), nil
+	}
+	const maxSafeBound = 15 // we encode it as number if it's small
+	s := z.Text(10)
+	if len(s) <= maxSafeBound {
+		return []byte(s), nil
+	}
+	var sbb strings.Builder
+	sbb.WriteByte('"')
+	sbb.WriteString(s)
+	sbb.WriteByte('"')
+	return []byte(sbb.String()), nil
+}
+
+// UnmarshalJSON accepts numbers and strings as input
+// See Element.SetString for valid prefixes (0x, 0b, ...)
+func (z *Element) UnmarshalJSON(data []byte) error {
+	s := string(data)
+	if len(s) > Bits*3 {
+		return errors.New("value too large (max = Element.Bits * 3)")
+	}
+
+	// we accept numbers and strings, remove leading and trailing quotes if any
+	if len(s) > 0 && s[0] == '"' {
+		s = s[1:]
+	}
+	if len(s) > 0 && s[len(s)-1] == '"' {
+		s = s[:len(s)-1]
+	}
+
+	// get temporary big int from the pool
+	vv := pool.BigInt.Get()
+
+	if _, ok := vv.SetString(s, 0); !ok {
+		return errors.New("can't parse into a big.Int: " + s)
+	}
+
+	z.SetBigInt(vv)
+
+	// release object into pool
+	pool.BigInt.Put(vv)
+	return nil
+}
+
+// A ByteOrder specifies how to convert byte slices into a Element
+type ByteOrder interface {
+	Element(*[Bytes]byte) (Element, error)
+	PutElement(*[Bytes]byte, Element)
+	String() string
+}
+
+// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder.
+var BigEndian bigEndian
+
+type bigEndian struct{}
+
+// Element interprets b as a big-endian 4-byte slice.
+// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint32((*b)[0:4]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid koalabear.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint32((*b)[0:4], e[0]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint32((*b)[0:4]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid koalabear.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint32((*b)[0:4], e[0]) +} + +func (littleEndian) String() string { return "LittleEndian" } + +// Legendre returns the Legendre symbol of z (either +1, -1, or 0.) +func (z *Element) Legendre() int { + var l Element + // z^((q-1)/2) + l.expByLegendreExp(*z) + + if l.IsZero() { + return 0 + } + + // if l == 1 + if l.IsOne() { + return 1 + } + return -1 +} + +// Sqrt z = √x (mod q) +// if the square root doesn't exist (x is not a square mod q) +// Sqrt leaves z unchanged and returns nil +func (z *Element) Sqrt(x *Element) *Element { + // q ≡ 1 (mod 4) + // see modSqrtTonelliShanks in math/big/int.go + // using https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf + + var y, b, t, w Element + // w = x^((s-1)/2)) + w.expBySqrtExp(*x) + + // y = x^((s+1)/2)) = w * x + y.Mul(x, &w) + + // b = xˢ = w * w * x = y * x + b.Mul(&w, &y) + + // g = nonResidue ^ s + var g = Element{ + 331895189, + } + r := uint64(24) + + // compute legendre symbol + // t = x^((q-1)/2) = r-1 squaring of xˢ + t = b + for i := uint64(0); i < r-1; i++ { + t.Square(&t) + } + if t.IsZero() { + return z.SetZero() + } + if !t.IsOne() { + // t != 1, we don't have a square root + return nil + } + for { + var m uint64 + t = b + + // for t != 1 + for !t.IsOne() { + t.Square(&t) + m++ + } + + if m == 0 { + return z.Set(&y) + } + // t = g^(2^(r-m-1)) (mod q) + ge := int(r - m - 1) + t = g + for ge > 0 { + t.Square(&t) + ge-- + } + + g.Square(&t) + y.Mul(&y, &t) + b.Mul(&b, &g) + r = m + } +} + +// Inverse z = x⁻¹ (mod q) +// +// note: allocates a big.Int (math/big) +func (z *Element) Inverse(x *Element) *Element { + var _xNonMont big.Int + x.BigInt(&_xNonMont) + _xNonMont.ModInverse(&_xNonMont, Modulus()) + z.SetBigInt(&_xNonMont) + return z +} diff --git a/field/koalabear/element_purego.go b/field/koalabear/element_purego.go index 21b1eb5493..ea28aa3cd5 100644 --- a/field/koalabear/element_purego.go +++ b/field/koalabear/element_purego.go @@ -18,7 +18,26 @@ package koalabear import "math/bits" +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + var y Element + y.SetUint64(3) + x.Mul(x, &y) +} +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + var y Element + y.SetUint64(5) + x.Mul(x, &y) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y Element + y.SetUint64(13) + x.Mul(x, &y) +} func fromMont(z *Element) { _fromMontGeneric(z) @@ -45,14 +64,14 @@ func (z *Element) Mul(x, y *Element) *Element { // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / 
R = hi2 + (lo+lo2) / R = hi2 + (lo != 0)
 	// This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R.
 
-	var r uint64
-	hi, lo := bits.Mul64(x[0], y[0])
+	var r uint32
+	hi, lo := bits.Mul32(x[0], y[0])
 	if lo != 0 {
 		hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow
 	}
 	m := lo * qInvNeg
-	hi2, _ := bits.Mul64(m, q)
-	r, carry := bits.Add64(hi2, hi, 0)
+	hi2, _ := bits.Mul32(m, q)
+	r, carry := bits.Add32(hi2, hi, 0)
 
 	if carry != 0 || r >= q {
 		// we need to reduce
@@ -62,3 +81,47 @@
 
 	return z
 }
+
+// Square z = x * x (mod q)
+//
+// x must be less than q
+func (z *Element) Square(x *Element) *Element {
+	// see Mul for algorithm documentation
+
+	// In fact, since the modulus fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction):
+	// hi, lo := x * x
+	// m := (lo * qInvNeg) mod R
+	// (*) r := (hi * R + lo + m * q) / R
+	// reduce r if necessary
+
+	// On the emphasized line, we get r = hi + (lo + m * q) / R
+	// If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2
+	// This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise.
+	// Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0)
+	// This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R.
+
+	var r uint32
+	hi, lo := bits.Mul32(x[0], x[0])
+	if lo != 0 {
+		hi++ // x[0] * x[0] ≤ 2⁶⁴ - 2³³ + 1, meaning hi ≤ 2³² - 2 so no need to worry about overflow
+	}
+	m := lo * qInvNeg
+	hi2, _ := bits.Mul32(m, q)
+	r, carry := bits.Add32(hi2, hi, 0)
+
+	if carry != 0 || r >= q {
+		// we need to reduce
+		r -= q
+	}
+	z[0] = r
+
+	return z
+}
+
+// Butterfly sets
+//
+//	a = a + b (mod q)
+//	b = a - b (mod q)
+func Butterfly(a, b *Element) {
+	_butterflyGeneric(a, b)
+}
diff --git a/field/koalabear/element_test.go b/field/koalabear/element_test.go
index bacc70706c..13899466b6 100644
--- a/field/koalabear/element_test.go
+++ b/field/koalabear/element_test.go
@@ -17,12 +17,16 @@
 package koalabear
 
 import (
+	"crypto/rand"
+	"encoding/json"
+	"fmt"
 	"math/big"
 	"math/bits"
 	"testing"
 
 	"github.com/leanovate/gopter"
+	ggen "github.com/leanovate/gopter/gen"
 	"github.com/leanovate/gopter/prop"
 
 	"github.com/stretchr/testify/require"
@@ -35,6 +39,95 @@
 
 var benchResElement Element
 
+func BenchmarkElementSelect(b *testing.B) {
+	var x, y Element
+	x.SetRandom()
+	y.SetRandom()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		benchResElement.Select(i%3, &x, &y)
+	}
+}
+
+func BenchmarkElementSetRandom(b *testing.B) {
+	var x Element
+	x.SetRandom()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _ = x.SetRandom()
+	}
+}
+
+func BenchmarkElementSetBytes(b *testing.B) {
+	var x Element
+	x.SetRandom()
+	bb := x.Bytes()
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		benchResElement.SetBytes(bb[:])
+	}
+
+}
+
+func BenchmarkElementMulByConstants(b *testing.B) {
+	b.Run("mulBy3", func(b *testing.B) {
+		benchResElement.SetRandom()
+		b.ResetTimer()
+		for i := 0; i < b.N; i++ {
+			MulBy3(&benchResElement)
+		}
+	})
+	b.Run("mulBy5", func(b *testing.B) {
+		benchResElement.SetRandom()
+		b.ResetTimer()
+		for i := 0; i < b.N; i++ {
+			MulBy5(&benchResElement)
+		}
+	})
+	b.Run("mulBy13", func(b *testing.B) {
+		benchResElement.SetRandom()
+		b.ResetTimer()
+		for i := 0; i < b.N; i++ {
+
MulBy13(&benchResElement) + } + }) +} + +func BenchmarkElementInverse(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.Inverse(&x) + } + +} + +func BenchmarkElementButterfly(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Butterfly(&x, &benchResElement) + } +} + +func BenchmarkElementExp(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b1, _ := rand.Int(rand.Reader, Modulus()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Exp(x, b1) + } +} func BenchmarkElementDouble(b *testing.B) { benchResElement.SetRandom() @@ -72,6 +165,16 @@ func BenchmarkElementNeg(b *testing.B) { } } +func BenchmarkElementDiv(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Div(&x, &benchResElement) + } +} + func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -80,12 +183,27 @@ func BenchmarkElementFromMont(b *testing.B) { } } +func BenchmarkElementSquare(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Square(&benchResElement) + } +} - +func BenchmarkElementSqrt(b *testing.B) { + var a Element + a.SetUint64(4) + a.Neg(&a) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sqrt(&a) + } +} func BenchmarkElementMul(b *testing.B) { x := Element{ - 1111325836, + 402124772, } benchResElement.SetOne() b.ResetTimer() @@ -94,6 +212,48 @@ func BenchmarkElementMul(b *testing.B) { } } +func BenchmarkElementCmp(b *testing.B) { + x := Element{ + 402124772, + } + benchResElement = x + benchResElement[0] = 0 + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Cmp(&x) + } +} + +func TestElementCmp(t *testing.T) { + var x, y Element + + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + one := One() + y.Sub(&y, &one) + + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } + + x = y + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + x.Sub(&x, &one) + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } +} func TestElementNegZero(t *testing.T) { var a, b Element @@ -226,8 +386,7 @@ func TestElementEqual(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } - -func TestElementAdd(t *testing.T) { +func TestElementBytes(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { @@ -239,106 +398,146 @@ func TestElementAdd(t *testing.T) { properties := gopter.NewProperties(parameters) genA := gen() - genB := gen() - - properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - d.Set(&a.element) - - c.Add(&a.element, &b.element) - a.element.Add(&a.element, &b.element) - b.element.Add(&d, &b.element) - return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + properties.Property("SetBytes(Bytes()) should stay constant", prop.ForAll( + func(a testPairElement) bool { + var b Element + bytes := a.element.Bytes() + b.SetBytes(bytes[:]) + return a.element.Equal(&b) }, genA, - genB, )) - properties.Property("Add: operation result must match big.Int result", prop.ForAll( - func(a, b testPairElement) bool { - { - var c Element + properties.TestingRun(t, 
gopter.ConsoleReporter(false)) +} - c.Add(&a.element, &b.element) +func TestElementInverseExp(t *testing.T) { + // inverse must be equal to exp^-2 + exp := Modulus() + exp.Sub(exp, new(big.Int).SetUint64(2)) - var d, e big.Int - d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) + invMatchExp := func(a testPairElement) bool { + var b Element + b.Set(&a.element) + a.element.Inverse(&a.element) + b.Exp(b, exp) - if c.BigInt(&e).Cmp(&d) != 0 { - return false - } - } + return a.element.Equal(&b) + } - // fixed elements - // a is random - // r takes special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + properties := gopter.NewProperties(parameters) + genA := gen() + properties.Property("inv == exp^-2", prop.ForAll(invMatchExp, genA)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) - for i := range testValues { - r := testValues[i] - var d, e, rb big.Int - r.BigInt(&rb) + parameters.MinSuccessfulTests = 1 + properties = gopter.NewProperties(parameters) + properties.Property("inv(0) == 0", prop.ForAll(invMatchExp, ggen.OneConstOf(testPairElement{}))) + properties.TestingRun(t, gopter.ConsoleReporter(false)) - var c Element - c.Add(&a.element, &r) - d.Add(&a.bigint, &rb).Mod(&d, Modulus()) +} - if c.BigInt(&e).Cmp(&d) != 0 { +func mulByConstant(z *Element, c uint8) { + var y Element + y.SetUint64(uint64(c)) + z.Mul(z, &y) +} + +func TestElementMulByConstants(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + implemented := []uint8{0, 1, 2, 3, 5, 13} + properties.Property("mulByConstant", prop.ForAll( + func(a testPairElement) bool { + for _, c := range implemented { + var constant Element + constant.SetUint64(uint64(c)) + + b := a.element + b.Mul(&b, &constant) + + aa := a.element + mulByConstant(&aa, c) + + if !aa.Equal(&b) { return false } } + return true }, genA, - genB, )) - properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( - func(a, b testPairElement) bool { - var c Element + properties.Property("MulBy3(x) == Mul(x, 3)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(3) - c.Add(&a.element, &b.element) + b := a.element + b.Mul(&b, &constant) - return c.smallerThanModulus() + MulBy3(&a.element) + + return a.element.Equal(&b) }, genA, - genB, )) - specialValueTest := func() { - // test special values against special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) + properties.Property("MulBy5(x) == Mul(x, 5)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(5) - for i := range testValues { - a := testValues[i] - var aBig big.Int - a.BigInt(&aBig) - for j := range testValues { - b := testValues[j] - var bBig, d, e big.Int - b.BigInt(&bBig) + b := a.element + b.Mul(&b, &constant) - var c Element - c.Add(&a, &b) - d.Add(&aBig, &bBig).Mod(&d, Modulus()) + MulBy5(&a.element) - if c.BigInt(&e).Cmp(&d) != 0 { - t.Fatal("Add failed special test values") - } - } - } - } + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy13(x) == Mul(x, 13)", 
prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(13) + + b := a.element + b.Mul(&b, &constant) + + MulBy13(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() } -func TestElementSub(t *testing.T) { +func TestElementLegendre(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { @@ -350,43 +549,268 @@ func TestElementSub(t *testing.T) { properties := gopter.NewProperties(parameters) genA := gen() - genB := gen() - - properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - d.Set(&a.element) - - c.Sub(&a.element, &b.element) - a.element.Sub(&a.element, &b.element) - b.element.Sub(&d, &b.element) - return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + properties.Property("legendre should output same result than big.Int.Jacobi", prop.ForAll( + func(a testPairElement) bool { + return a.element.Legendre() == big.Jacobi(&a.bigint, Modulus()) }, genA, - genB, )) - properties.Property("Sub: operation result must match big.Int result", prop.ForAll( - func(a, b testPairElement) bool { - { - var c Element + properties.TestingRun(t, gopter.ConsoleReporter(false)) - c.Sub(&a.element, &b.element) +} - var d, e big.Int - d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) +func TestElementBitLen(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } - if c.BigInt(&e).Cmp(&d) != 0 { - return false - } - } + properties := gopter.NewProperties(parameters) - // fixed elements - // a is random - // r takes special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) + genA := gen() + + properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( + func(a testPairElement) bool { + return a.element.fromMont().BitLen() == a.bigint.BitLen() + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementButterflies(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("butterfly0 == a -b; a +b", prop.ForAll( + func(a, b testPairElement) bool { + a0, b0 := a.element, b.element + + _butterflyGeneric(&a.element, &b.element) + Butterfly(&a0, &b0) + + return a.element.Equal(&a0) && b.element.Equal(&b0) + }, + genA, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementLexicographicallyLargest(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("element.Cmp should match LexicographicallyLargest output", prop.ForAll( + func(a testPairElement) bool { + var negA Element + negA.Neg(&a.element) + + cmpResult := a.element.Cmp(&negA) + lResult := a.element.LexicographicallyLargest() + + if lResult && cmpResult == 1 { + return true + } + if !lResult && cmpResult != 1 
{ + return true + } + return false + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementAdd(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Add(&a.element, &b.element) + a.element.Add(&a.element, &b.element) + b.element.Add(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Add: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Add(&a.element, &b.element) + + var d, e big.Int + d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Add(&a.element, &r) + d.Add(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Add(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Add(&a, &b) + d.Add(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Add failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSub(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Sub(&a.element, &b.element) + a.element.Sub(&a.element, &b.element) + b.element.Sub(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Sub(&a.element, &b.element) + + var d, e big.Int + d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) 
+ copy(testValues, staticTestValues) for i := range testValues { r := testValues[i] @@ -509,6 +933,12 @@ func TestElementMul(t *testing.T) { d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) // checking generic impl against asm path + var cGeneric Element + _mulGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. + return false + } if c.BigInt(&e).Cmp(&d) != 0 { return false @@ -532,6 +962,16 @@ func TestElementMul(t *testing.T) { genB, )) + properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Mul(&a.element, &b.element) + _mulGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) specialValueTest := func() { // test special values against special values @@ -552,6 +992,12 @@ func TestElementMul(t *testing.T) { d.Mul(&aBig, &bBig).Mod(&d, Modulus()) // checking asm against generic impl + var cGeneric Element + _mulGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Mul failed special test values: asm and generic impl don't match") + } + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } @@ -564,8 +1010,7 @@ func TestElementMul(t *testing.T) { } - -func TestElementDouble(t *testing.T) { +func TestElementDiv(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { @@ -577,43 +1022,79 @@ func TestElementDouble(t *testing.T) { properties := gopter.NewProperties(parameters) genA := gen() + genB := gen() - properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( - func(a testPairElement) bool { + properties.Property("Div: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) - var b Element + c.Div(&a.element, &b.element) + a.element.Div(&a.element, &b.element) + b.element.Div(&d, &b.element) - b.Double(&a.element) - a.element.Double(&a.element) - return a.element.Equal(&b) + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) }, genA, + genB, )) - properties.Property("Double: operation result must match big.Int result", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Double(&a.element) + properties.Property("Div: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element - var d, e big.Int - d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) + c.Div(&a.element, &b.element) - return c.BigInt(&e).Cmp(&d) == 0 + var d, e big.Int + d.ModInverse(&b.bigint, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Div(&a.element, &r) + d.ModInverse(&rb, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true }, genA, + genB, )) - properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( - func(a testPairElement) bool { + properties.Property("Div: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { var c Element - c.Double(&a.element) + + c.Div(&a.element, 
&b.element) + return c.smallerThanModulus() }, genA, + genB, )) specialValueTest := func() { - // test special values + // test special values against special values testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) @@ -621,14 +1102,130 @@ func TestElementDouble(t *testing.T) { a := testValues[i] var aBig big.Int a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Div(&a, &b) + d.ModInverse(&bBig, Modulus()) + d.Mul(&d, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Div failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementExp(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Exp: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Exp(a.element, &b.bigint) + a.element.Exp(a.element, &b.bigint) + b.element.Exp(d, &b.bigint) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Exp(a.element, &b.bigint) + + var d, e big.Int + d.Exp(&a.bigint, &b.bigint, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Exp(a.element, &rb) + d.Exp(&a.bigint, &rb, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { var c Element - c.Double(&a) - var d, e big.Int - d.Lsh(&aBig, 1).Mod(&d, Modulus()) + c.Exp(a.element, &b.bigint) - if c.BigInt(&e).Cmp(&d) != 0 { - t.Fatal("Double failed special test values") + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Exp(a, &bBig) + d.Exp(&aBig, &bBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Exp failed special test values") + } } } } @@ -638,7 +1235,7 @@ func TestElementDouble(t *testing.T) { } -func TestElementNeg(t *testing.T) { +func TestElementSquare(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { @@ -651,35 +1248,35 @@ func TestElementNeg(t *testing.T) { genA := gen() - properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( + properties.Property("Square: having the receiver as operand should output the same 
result", prop.ForAll( func(a testPairElement) bool { var b Element - b.Neg(&a.element) - a.element.Neg(&a.element) + b.Square(&a.element) + a.element.Square(&a.element) return a.element.Equal(&b) }, genA, )) - properties.Property("Neg: operation result must match big.Int result", prop.ForAll( + properties.Property("Square: operation result must match big.Int result", prop.ForAll( func(a testPairElement) bool { var c Element - c.Neg(&a.element) + c.Square(&a.element) var d, e big.Int - d.Neg(&a.bigint).Mod(&d, Modulus()) + d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) - properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( + properties.Property("Square: operation result must be smaller than modulus", prop.ForAll( func(a testPairElement) bool { var c Element - c.Neg(&a.element) + c.Square(&a.element) return c.smallerThanModulus() }, genA, @@ -695,13 +1292,13 @@ func TestElementNeg(t *testing.T) { var aBig big.Int a.BigInt(&aBig) var c Element - c.Neg(&a) + c.Square(&a) var d, e big.Int - d.Neg(&aBig).Mod(&d, Modulus()) + d.Mul(&aBig, &aBig).Mod(&d, Modulus()) if c.BigInt(&e).Cmp(&d) != 0 { - t.Fatal("Neg failed special test values") + t.Fatal("Square failed special test values") } } } @@ -711,21 +1308,80 @@ func TestElementNeg(t *testing.T) { } -func TestElementNewElement(t *testing.T) { - assert := require.New(t) - +func TestElementInverse(t *testing.T) { t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } - e := NewElement(1) - assert.True(e.IsOne()) + properties := gopter.NewProperties(parameters) - e = NewElement(0) - assert.True(e.IsZero()) -} + genA := gen() + properties.Property("Inverse: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { -func TestElementFromMont(t *testing.T) { + var b Element + + b.Inverse(&a.element) + a.element.Inverse(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Inverse: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + + var d, e big.Int + d.ModInverse(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Inverse: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Inverse(&a) + + var d, e big.Int + d.ModInverse(&aBig, Modulus()) + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Inverse failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSqrt(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { @@ -738,27 +1394,791 @@ func TestElementFromMont(t *testing.T) { genA := gen() - properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + properties.Property("Sqrt: having the receiver as operand should output the same result", prop.ForAll( func(a testPairElement) bool { - c := a.element 
- d := a.element - c.fromMont() - _fromMontGeneric(&d) - return c.Equal(&d) + + b := a.element + + b.Sqrt(&a.element) + a.element.Sqrt(&a.element) + return a.element.Equal(&b) }, genA, )) - properties.Property("x.fromMont().toMont() == x", prop.ForAll( + properties.Property("Sqrt: operation result must match big.Int result", prop.ForAll( func(a testPairElement) bool { - c := a.element - c.fromMont().toMont() - return c.Equal(&a.element) + var c Element + c.Sqrt(&a.element) + + var d, e big.Int + d.ModSqrt(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.Property("Sqrt: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Sqrt(&a) + + var d, e big.Int + d.ModSqrt(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sqrt failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDouble(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Double(&a.element) + a.element.Double(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Double: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + + var d, e big.Int + d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Double(&a) + + var d, e big.Int + d.Lsh(&aBig, 1).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Double failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementNeg(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Neg(&a.element) + a.element.Neg(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Neg: operation result must match big.Int result", prop.ForAll( + 
func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + + var d, e big.Int + d.Neg(&a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Neg(&a) + + var d, e big.Int + d.Neg(&aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Neg failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementFixedExp(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + var ( + _bLegendreExponentElement *big.Int + _bSqrtExponentElement *big.Int + ) + + _bLegendreExponentElement, _ = new(big.Int).SetString("3f800000", 16) + const sqrtExponentElement = "3f" + _bSqrtExponentElement, _ = new(big.Int).SetString(sqrtExponentElement, 16) + + genA := gen() + + properties.Property(fmt.Sprintf("expBySqrtExp must match Exp(%s)", sqrtExponentElement), prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expBySqrtExp(c) + d.Exp(d, _bSqrtExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("expByLegendreExp must match Exp(3f800000)", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expByLegendreExp(c) + d.Exp(d, _bLegendreExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementHalve(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + var twoInv Element + twoInv.SetUint64(2) + twoInv.Inverse(&twoInv) + + properties.Property("z.Halve must match z / 2", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.Halve() + d.Mul(&d, &twoInv) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func combineSelectionArguments(c int64, z int8) int { + if z%3 == 0 { + return 0 + } + return int(c) +} + +func TestElementSelect(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + genB := genFull() + genC := ggen.Int64() //the condition + genZ := ggen.Int8() //to make zeros artificially more likely + + properties.Property("Select: must select correctly", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c Element + c.Select(condC, &a, &b) + + if condC == 0 { + return c.Equal(&a) + } + return c.Equal(&b) + }, + genA, + genB, + genC, + genZ, + )) + + 
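+	// A deterministic spot-check of the Select contract, mirroring the
+	// property above with explicit constants (sketch: c == 0 picks x0,
+	// any c != 0 picks x1):
+	{
+		var x0, x1, z Element
+		x0.SetUint64(1)
+		x1.SetUint64(2)
+		if !z.Select(0, &x0, &x1).Equal(&x0) {
+			t.Fatal("Select(0, x0, x1) must return x0")
+		}
+		if !z.Select(42, &x0, &x1).Equal(&x1) {
+			t.Fatal("Select(c != 0, x0, x1) must return x1")
+		}
+	}
+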
properties.Property("Select: having the receiver as operand should output the same result", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c, d Element + d.Set(&a) + c.Select(condC, &a, &b) + a.Select(condC, &a, &b) + b.Select(condC, &d, &b) + return a.Equal(&b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + genC, + genZ, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInt64(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("z.SetInt64 must match z.SetString", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInt64(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, ggen.Int64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInterface(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genInt := ggen.Int + genInt8 := ggen.Int8 + genInt16 := ggen.Int16 + genInt32 := ggen.Int32 + genInt64 := ggen.Int64 + + genUint := ggen.UInt + genUint8 := ggen.UInt8 + genUint16 := ggen.UInt16 + genUint32 := ggen.UInt32 + genUint64 := ggen.UInt64 + + properties.Property("z.SetInterface must match z.SetString with int8", prop.ForAll( + func(a testPairElement, v int8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt8(), + )) + + properties.Property("z.SetInterface must match z.SetString with int16", prop.ForAll( + func(a testPairElement, v int16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt16(), + )) + + properties.Property("z.SetInterface must match z.SetString with int32", prop.ForAll( + func(a testPairElement, v int32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt32(), + )) + + properties.Property("z.SetInterface must match z.SetString with int64", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt64(), + )) + + properties.Property("z.SetInterface must match z.SetString with int", prop.ForAll( + func(a testPairElement, v int) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint8", prop.ForAll( + func(a testPairElement, v uint8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint8(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint16", prop.ForAll( + func(a testPairElement, v uint16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, 
genUint16(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint32", prop.ForAll( + func(a testPairElement, v uint32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint32(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint64", prop.ForAll( + func(a testPairElement, v uint64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint64(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint", prop.ForAll( + func(a testPairElement, v uint) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + { + assert := require.New(t) + var e Element + r, err := e.SetInterface(nil) + assert.Nil(r) + assert.Error(err) + + var ptE *Element + var ptB *big.Int + + r, err = e.SetInterface(ptE) + assert.Nil(r) + assert.Error(err) + ptE = new(Element).SetOne() + r, err = e.SetInterface(ptE) + assert.NoError(err) + assert.True(r.IsOne()) + + r, err = e.SetInterface(ptB) + assert.Nil(r) + assert.Error(err) + + } +} + +func TestElementNegativeExp(t *testing.T) { + t.Parallel() + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("x⁻ᵏ == 1/xᵏ", prop.ForAll( + func(a, b testPairElement) bool { + + var nb, d, e big.Int + nb.Neg(&b.bigint) + + var c Element + c.Exp(a.element, &nb) + + d.Exp(&a.bigint, &nb, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNewElement(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + e := NewElement(1) + assert.True(e.IsOne()) + + e = NewElement(0) + assert.True(e.IsZero()) +} + +func TestElementBatchInvert(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + // ensure batchInvert([x]) == invert(x) + for i := int64(-1); i <= 2; i++ { + var e, eInv Element + e.SetInt64(i) + eInv.Inverse(&e) + + a := []Element{e} + aInv := BatchInvert(a) + + assert.True(aInv[0].Equal(&eInv), "batchInvert != invert") + + } + + // test x * x⁻¹ == 1 + tData := [][]int64{ + {-1, 1, 2, 3}, + {0, -1, 1, 2, 3, 0}, + {0, -1, 1, 0, 2, 3, 0}, + {-1, 1, 0, 2, 3}, + {0, 0, 1}, + {1, 0, 0}, + {0, 0, 0}, + } + + for _, t := range tData { + a := make([]Element, len(t)) + for i := 0; i < len(a); i++ { + a[i].SetInt64(t[i]) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + assert.True(aInv[i].IsZero(), "0⁻¹ != 0") + } else { + assert.True(a[i].Mul(&a[i], &aInv[i]).IsOne(), "x * x⁻¹ != 1") + } + } + } + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("batchInvert --> x * x⁻¹ == 1", prop.ForAll( + func(tp testPairElement, r uint8) bool { + + a := make([]Element, r) + if r != 0 { + a[0] = tp.element + + } + one := One() + for i := 1; i < len(a); i++ { + a[i].Add(&a[i-1], &one) + } + + aInv := 
BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + if !aInv[i].IsZero() { + return false + } + } else { + if !a[i].Mul(&a[i], &aInv[i]).IsOne() { + return false + } + } + } + return true + }, + genA, ggen.UInt8(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementFromMont(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.fromMont() + _fromMontGeneric(&d) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("x.fromMont().toMont() == x", prop.ForAll( + func(a testPairElement) bool { + c := a.element + c.fromMont().toMont() + return c.Equal(&a.element) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementJSON(t *testing.T) { + assert := require.New(t) + + type S struct { + A Element + B [3]Element + C *Element + D *Element + } + + // encode to JSON + var s S + s.A.SetString("-1") + s.B[2].SetUint64(42) + s.D = new(Element).SetUint64(8000) + + encoded, err := json.Marshal(&s) + assert.NoError(err) + // we may need to adjust "42" and "8000" values for some moduli; see Text() method for more details. + formatValue := func(v int64) string { + var a big.Int + a.SetInt64(v) + a.Mod(&a, Modulus()) + const maxUint16 = 65535 + var aNeg big.Int + aNeg.Neg(&a).Mod(&aNeg, Modulus()) + if aNeg.Uint64() != 0 && aNeg.Uint64() <= maxUint16 { + return "-" + aNeg.Text(10) + } + return a.Text(10) + } + expected := fmt.Sprintf("{\"A\":%s,\"B\":[0,0,%s],\"C\":null,\"D\":%s}", formatValue(-1), formatValue(42), formatValue(8000)) + assert.Equal(expected, string(encoded)) + + // decode valid + var decoded S + err = json.Unmarshal([]byte(expected), &decoded) + assert.NoError(err) + + assert.Equal(s, decoded, "element -> json -> element round trip failed") + + // decode hex and string values + withHexValues := "{\"A\":\"-1\",\"B\":[0,\"0x00000\",\"0x2A\"],\"C\":null,\"D\":\"8000\"}" + + var decodedS S + err = json.Unmarshal([]byte(withHexValues), &decodedS) + assert.NoError(err) + + assert.Equal(s, decodedS, " json with strings -> element failed") + } type testPairElement struct { @@ -771,17 +2191,17 @@ func gen() gopter.Gen { var g testPairElement g.element = Element{ - genParams.NextUint64(), + uint32(genParams.NextUint64()), } - if qElement[0] != ^uint64(0) { + if qElement[0] != ^uint32(0) { g.element[0] %= (qElement[0] + 1) } for !g.element.smallerThanModulus() { g.element = Element{ - genParams.NextUint64(), + uint32(genParams.NextUint64()), } - if qElement[0] != ^uint64(0) { + if qElement[0] != ^uint32(0) { g.element[0] %= (qElement[0] + 1) } } @@ -796,18 +2216,18 @@ func genRandomFq(genParams *gopter.GenParameters) Element { var g Element g = Element{ - genParams.NextUint64(), + uint32(genParams.NextUint64()), } - if qElement[0] != ^uint64(0) { + if qElement[0] != ^uint32(0) { g[0] %= (qElement[0] + 1) } for !g.smallerThanModulus() { g = Element{ - genParams.NextUint64(), + uint32(genParams.NextUint64()), } - if qElement[0] != ^uint64(0) { + if qElement[0] != ^uint32(0) { g[0] %= (qElement[0] + 1) } } @@ -819,8 +2239,8 @@ func genFull() gopter.Gen { return func(genParams 
*gopter.GenParameters) *gopter.GenResult { a := genRandomFq(genParams) - var carry uint64 - a[0], _ = bits.Add64(a[0], qElement[0], carry) + var carry uint32 + a[0], _ = bits.Add32(a[0], qElement[0], carry) genResult := gopter.NewGenResult(a, gopter.NoShrinker) return genResult From dddb22da738ffaab3e2d3b57062c2936a0335299 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 25 Nov 2024 14:22:10 -0600 Subject: [PATCH 26/74] feat: add babybear and koalabear --- .github/workflows/pr.yml | 5 +- .github/workflows/push.yml | 3 + ecc/bls12-377/fp/element.go | 1 - ecc/bls12-377/fr/element.go | 1 - ecc/bls12-381/fp/element.go | 1 - ecc/bls12-381/fr/element.go | 1 - ecc/bls24-315/fp/element.go | 1 - ecc/bls24-315/fr/element.go | 1 - ecc/bls24-317/fp/element.go | 1 - ecc/bls24-317/fr/element.go | 1 - ecc/bn254/fp/element.go | 1 - ecc/bn254/fr/element.go | 1 - ecc/bw6-633/fp/element.go | 1 - ecc/bw6-633/fr/element.go | 1 - ecc/bw6-761/fp/element.go | 1 - ecc/bw6-761/fr/element.go | 1 - ecc/secp256k1/fp/element.go | 1 - ecc/secp256k1/fr/element.go | 1 - ecc/stark-curve/fp/element.go | 1 - ecc/stark-curve/fr/element.go | 1 - field/babybear/arith.go | 60 + field/babybear/doc.go | 53 + field/babybear/element.go | 975 +++++++ field/babybear/element_exp.go | 92 + field/babybear/element_purego.go | 127 + field/babybear/element_test.go | 2256 +++++++++++++++++ field/babybear/vector.go | 303 +++ field/babybear/vector_purego.go | 54 + field/babybear/vector_test.go | 365 +++ .../internal/templates/element/base.go | 44 +- .../internal/templates/element/conv.go | 29 +- .../internal/templates/element/tests.go | 12 +- .../templates/element/tests_vector.go | 9 +- field/goldilocks/element.go | 5 +- field/goldilocks/internal/main.go | 21 - field/internal/addchain/3c000000 | Bin 0 -> 206 bytes field/{koalabear => }/internal/addchain/3f | Bin .../internal/addchain/3f800000 | Bin field/internal/addchain/7 | Bin 0 -> 68 bytes .../internal/addchain/7fffffff | Bin 233 -> 232 bytes .../internal/addchain/7fffffff80000000 | Bin 394 -> 393 bytes field/internal/main.go | 36 + field/koalabear/element.go | 27 +- field/koalabear/element_exp.go | 122 + field/koalabear/internal/main.go | 21 - field/koalabear/vector.go | 303 +++ field/koalabear/vector_purego.go | 54 + field/koalabear/vector_test.go | 365 +++ 48 files changed, 5244 insertions(+), 115 deletions(-) create mode 100644 field/babybear/arith.go create mode 100644 field/babybear/doc.go create mode 100644 field/babybear/element.go create mode 100644 field/babybear/element_exp.go create mode 100644 field/babybear/element_purego.go create mode 100644 field/babybear/element_test.go create mode 100644 field/babybear/vector.go create mode 100644 field/babybear/vector_purego.go create mode 100644 field/babybear/vector_test.go delete mode 100644 field/goldilocks/internal/main.go create mode 100644 field/internal/addchain/3c000000 rename field/{koalabear => }/internal/addchain/3f (100%) rename field/{koalabear => }/internal/addchain/3f800000 (100%) create mode 100644 field/internal/addchain/7 rename field/{goldilocks => }/internal/addchain/7fffffff (66%) rename field/{goldilocks => }/internal/addchain/7fffffff80000000 (79%) create mode 100644 field/internal/main.go create mode 100644 field/koalabear/element_exp.go delete mode 100644 field/koalabear/internal/main.go create mode 100644 field/koalabear/vector.go create mode 100644 field/koalabear/vector_purego.go create mode 100644 field/koalabear/vector_test.go diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 
6da517779b..a6e3b4607e 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -61,7 +61,10 @@ jobs: go test -json -v -short -timeout=30m ./... 2>&1 | gotestfmt -hide=all | tee /tmp/gotest.log go test -json -v -tags=purego -timeout=30m ./... 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log go test -json -v -race -timeout=30m ./ecc/bn254/... 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log - GOARCH=386 go test -json -short -v -timeout=30m ./ecc/bn254/... 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log + GOARCH=386 go test -json -short -v -timeout=30m ./ecc/bn254/... 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log + GOARCH=386 go test -json -short -v -timeout=30m ./field/goldilocks 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log + GOARCH=386 go test -json -short -v -timeout=30m ./field/koalabear 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log + GOARCH=386 go test -json -short -v -timeout=30m ./field/babybear 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log slack-notifications: diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 1ed9bfa9bc..a762da4cb4 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -73,6 +73,9 @@ jobs: go test -json -v -tags=purego -timeout=30m ./... 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log go test -json -v -race -timeout=30m ./ecc/bn254/... 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log GOARCH=386 go test -json -short -v -timeout=30m ./ecc/bn254/... 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log + GOARCH=386 go test -json -short -v -timeout=30m ./field/goldilocks 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log + GOARCH=386 go test -json -short -v -timeout=30m ./field/koalabear 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log + GOARCH=386 go test -json -short -v -timeout=30m ./field/babybear 2>&1 | gotestfmt -hide=all | tee -a /tmp/gotest.log slack-notifications: diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 393f45744d..5266ace480 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -1119,7 +1119,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index af277e8bb1..d71abadc0d 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -960,7 +960,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index f0bcfe51bc..5808de88cd 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -1119,7 +1119,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index dc38f08cd3..f5f88277b8 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -960,7 +960,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { 
vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls24-315/fp/element.go b/ecc/bls24-315/fp/element.go index 4ab67695e3..6625335bc9 100644 --- a/ecc/bls24-315/fp/element.go +++ b/ecc/bls24-315/fp/element.go @@ -1035,7 +1035,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index abdb822acf..7c297b79dc 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -960,7 +960,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls24-317/fp/element.go b/ecc/bls24-317/fp/element.go index 77818de479..a4a2ee282e 100644 --- a/ecc/bls24-317/fp/element.go +++ b/ecc/bls24-317/fp/element.go @@ -1035,7 +1035,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls24-317/fr/element.go b/ecc/bls24-317/fr/element.go index 3aefaebe62..bf936ea4e9 100644 --- a/ecc/bls24-317/fr/element.go +++ b/ecc/bls24-317/fr/element.go @@ -960,7 +960,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index 25fcdb67cc..2b207e73b7 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -960,7 +960,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index 3650c954c5..eb95ff30e6 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -960,7 +960,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bw6-633/fp/element.go b/ecc/bw6-633/fp/element.go index 7656002f47..2165d16e17 100644 --- a/ecc/bw6-633/fp/element.go +++ b/ecc/bw6-633/fp/element.go @@ -1515,7 +1515,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bw6-633/fr/element.go b/ecc/bw6-633/fr/element.go index 8841cd342c..f019a15673 100644 --- a/ecc/bw6-633/fr/element.go +++ b/ecc/bw6-633/fr/element.go @@ -1035,7 +1035,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bw6-761/fp/element.go 
b/ecc/bw6-761/fp/element.go index 8cdd31218e..3c8dcc99aa 100644 --- a/ecc/bw6-761/fp/element.go +++ b/ecc/bw6-761/fp/element.go @@ -1749,7 +1749,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index 6784bc911f..a11193a4de 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -1119,7 +1119,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/secp256k1/fp/element.go b/ecc/secp256k1/fp/element.go index 73045a133c..382379a65d 100644 --- a/ecc/secp256k1/fp/element.go +++ b/ecc/secp256k1/fp/element.go @@ -988,7 +988,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/secp256k1/fr/element.go b/ecc/secp256k1/fr/element.go index e2f81b66b3..5b89009e92 100644 --- a/ecc/secp256k1/fr/element.go +++ b/ecc/secp256k1/fr/element.go @@ -988,7 +988,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/stark-curve/fp/element.go b/ecc/stark-curve/fp/element.go index 1c53dcb090..990494c9f5 100644 --- a/ecc/stark-curve/fp/element.go +++ b/ecc/stark-curve/fp/element.go @@ -960,7 +960,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/stark-curve/fr/element.go b/ecc/stark-curve/fr/element.go index 216e287ebb..bd0d4aae0d 100644 --- a/ecc/stark-curve/fr/element.go +++ b/ecc/stark-curve/fr/element.go @@ -960,7 +960,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/field/babybear/arith.go b/field/babybear/arith.go new file mode 100644 index 0000000000..03c952e9fb --- /dev/null +++ b/field/babybear/arith.go @@ -0,0 +1,60 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
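+
+// The madd helpers below are the 32-bit fused multiply-add building blocks
+// used by the Montgomery multiplication code. As an illustrative reading aid
+// (not part of the generated API), madd2 satisfies the identity
+//
+//	hi·2³² + lo == a·b + c + d
+//
+// so, for example, madd2(3, 5, 7, 11) returns (0, 33).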
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package babybear
+
+import (
+	"math/bits"
+)
+
+// madd0 hi = a*b + c (discards lo bits)
+func madd0(a, b, c uint32) (hi uint32) {
+	var carry, lo uint32
+	hi, lo = bits.Mul32(a, b)
+	_, carry = bits.Add32(lo, c, 0)
+	hi, _ = bits.Add32(hi, 0, carry)
+	return
+}
+
+// madd1 hi, lo = a*b + c
+func madd1(a, b, c uint32) (hi uint32, lo uint32) {
+	var carry uint32
+	hi, lo = bits.Mul32(a, b)
+	lo, carry = bits.Add32(lo, c, 0)
+	hi, _ = bits.Add32(hi, 0, carry)
+	return
+}
+
+// madd2 hi, lo = a*b + c + d
+func madd2(a, b, c, d uint32) (hi uint32, lo uint32) {
+	var carry uint32
+	hi, lo = bits.Mul32(a, b)
+	c, carry = bits.Add32(c, d, 0)
+	hi, _ = bits.Add32(hi, 0, carry)
+	lo, carry = bits.Add32(lo, c, 0)
+	hi, _ = bits.Add32(hi, 0, carry)
+	return
+}
+
+// madd3 hi, lo = a*b + c + d, with e added into the high word
+func madd3(a, b, c, d, e uint32) (hi uint32, lo uint32) {
+	var carry uint32
+	hi, lo = bits.Mul32(a, b)
+	c, carry = bits.Add32(c, d, 0)
+	hi, _ = bits.Add32(hi, 0, carry)
+	lo, carry = bits.Add32(lo, c, 0)
+	hi, _ = bits.Add32(hi, e, carry)
+	return
+}
diff --git a/field/babybear/doc.go b/field/babybear/doc.go
new file mode 100644
index 0000000000..65011bef3f
--- /dev/null
+++ b/field/babybear/doc.go
@@ -0,0 +1,53 @@
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+// Package babybear contains field arithmetic operations for modulus = 0x78000001.
+//
+// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@gnark/modular_multiplication)
+//
+// The modulus is hardcoded in all the operations.
+//
+// Field elements are represented as an array, and assumed to be in Montgomery form in all methods:
+//
+//	type Element [1]uint32
+//
+// # Usage
+//
+// Example API signature:
+//
+//	// Mul z = x * y (mod q)
+//	func (z *Element) Mul(x, y *Element) *Element
+//
+// and can be used like so:
+//
+//	var a, b Element
+//	a.SetUint64(2)
+//	b.SetString("984896738")
+//	a.Mul(&a, &b)
+//	a.Sub(&a, &a)
+//	a.Add(&a, &b)
+//	a.Inverse(&a)
+//	b.Exp(b, new(big.Int).SetUint64(42))
+//
+// Modulus q =
+//
+//	q[base10] = 2013265921
+//	q[base16] = 0x78000001
+//
+// # Warning
+//
+// This code has not been audited and is provided as-is. In particular, there are no security guarantees such as constant time implementation or side-channel attack resistance.
+package babybear
diff --git a/field/babybear/element.go b/field/babybear/element.go
new file mode 100644
index 0000000000..b4f4fb4f23
--- /dev/null
+++ b/field/babybear/element.go
@@ -0,0 +1,975 @@
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package babybear
+
+import (
+	"crypto/rand"
+	"encoding/binary"
+	"errors"
+	"io"
+	"math/big"
+	"math/bits"
+	"reflect"
+	"strconv"
+	"strings"
+
+	"github.com/bits-and-blooms/bitset"
+	"github.com/consensys/gnark-crypto/field/hash"
+	"github.com/consensys/gnark-crypto/field/pool"
+)
+
+// Element represents a field element stored on 1 word (uint32)
+//
+// Elements are assumed to be in Montgomery form in all methods.
+//
+// Modulus q =
+//
+//	q[base10] = 2013265921
+//	q[base16] = 0x78000001
+//
+// # Warning
+//
+// This code has not been audited and is provided as-is. In particular, there are no security guarantees such as constant time implementation or side-channel attack resistance.
+type Element [1]uint32
+
+const (
+	Limbs = 1  // number of 32 bits words needed to represent an Element
+	Bits  = 31 // number of bits needed to represent an Element
+	Bytes = 4  // number of bytes needed to represent an Element
+)
+
+// Field modulus q
+const (
+	q0 uint32 = 2013265921
+	q  uint32 = q0
+)
+
+var qElement = Element{
+	q0,
+}
+
+var _modulus big.Int // q stored as big.Int
+
+// Modulus returns q as a big.Int
+//
+//	q[base10] = 2013265921
+//	q[base16] = 0x78000001
+func Modulus() *big.Int {
+	return new(big.Int).Set(&_modulus)
+}
+
+// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r
+// used for Montgomery reduction
+const qInvNeg uint32 = 2013265919
+
+func init() {
+	_modulus.SetString("78000001", 16)
+}
+
+// NewElement returns a new Element from a uint64 value
+//
+// it is equivalent to
+//
+//	var v Element
+//	v.SetUint64(...)
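+//
+// For instance (illustrative): the input is reduced mod q, so
+//
+//	x := NewElement(uint64(q0) + 3) // x == 3
+//	y := NewElement(3)
+//	// x.Equal(&y) == true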
+func NewElement(v uint64) Element {
+	z := Element{uint32(v % uint64(q0))}
+	z.Mul(&z, &rSquare)
+	return z
+}
+
+// SetUint64 sets z to v and returns z
+func (z *Element) SetUint64(v uint64) *Element {
+	// sets z LSB to v (non-Montgomery form) and convert z to Montgomery form
+	*z = Element{uint32(v % uint64(q0))}
+	return z.Mul(z, &rSquare) // z.toMont()
+}
+
+// SetInt64 sets z to v and returns z
+func (z *Element) SetInt64(v int64) *Element {
+
+	// absolute value of v
+	m := v >> 63
+	z.SetUint64(uint64((v ^ m) - m))
+
+	if m != 0 {
+		// v is negative
+		z.Neg(z)
+	}
+
+	return z
+}
+
+// Set z = x and returns z
+func (z *Element) Set(x *Element) *Element {
+	z[0] = x[0]
+	return z
+}
+
+// SetInterface converts provided interface into Element
+// returns an error if provided type is not supported
+// supported types:
+//
+//	Element
+//	*Element
+//	uint64
+//	int
+//	string (see SetString for valid formats)
+//	*big.Int
+//	big.Int
+//	[]byte
+func (z *Element) SetInterface(i1 interface{}) (*Element, error) {
+	if i1 == nil {
+		return nil, errors.New("can't set babybear.Element with <nil>")
+	}
+
+	switch c1 := i1.(type) {
+	case Element:
+		return z.Set(&c1), nil
+	case *Element:
+		if c1 == nil {
+			return nil, errors.New("can't set babybear.Element with <nil>")
+		}
+		return z.Set(c1), nil
+	case uint8:
+		return z.SetUint64(uint64(c1)), nil
+	case uint16:
+		return z.SetUint64(uint64(c1)), nil
+	case uint32:
+		return z.SetUint64(uint64(c1)), nil
+	case uint:
+		return z.SetUint64(uint64(c1)), nil
+	case uint64:
+		return z.SetUint64(c1), nil
+	case int8:
+		return z.SetInt64(int64(c1)), nil
+	case int16:
+		return z.SetInt64(int64(c1)), nil
+	case int32:
+		return z.SetInt64(int64(c1)), nil
+	case int64:
+		return z.SetInt64(c1), nil
+	case int:
+		return z.SetInt64(int64(c1)), nil
+	case string:
+		return z.SetString(c1)
+	case *big.Int:
+		if c1 == nil {
+			return nil, errors.New("can't set babybear.Element with <nil>")
+		}
+		return z.SetBigInt(c1), nil
+	case big.Int:
+		return z.SetBigInt(&c1), nil
+	case []byte:
+		return z.SetBytes(c1), nil
+	default:
+		return nil, errors.New("can't set babybear.Element from type " + reflect.TypeOf(i1).String())
+	}
+}
+
+// SetZero z = 0
+func (z *Element) SetZero() *Element {
+	z[0] = 0
+	return z
+}
+
+// SetOne z = 1 (in Montgomery form)
+func (z *Element) SetOne() *Element {
+	z[0] = 268435454
+	return z
+}
+
+// Div z = x*y⁻¹ (mod q)
+func (z *Element) Div(x, y *Element) *Element {
+	var yInv Element
+	yInv.Inverse(y)
+	z.Mul(x, &yInv)
+	return z
+}
+
+// Equal returns z == x; constant-time
+func (z *Element) Equal(x *Element) bool {
+	return z.NotEqual(x) == 0
+}
+
+// NotEqual returns 0 if and only if z == x; constant-time
+func (z *Element) NotEqual(x *Element) uint32 {
+	return (z[0] ^ x[0])
+}
+
+// IsZero returns z == 0
+func (z *Element) IsZero() bool {
+	return (z[0]) == 0
+}
+
+// IsOne returns z == 1
+func (z *Element) IsOne() bool {
+	return z[0] == 268435454
+}
+
+// IsUint64 reports whether z can be represented as a uint64.
+func (z *Element) IsUint64() bool {
+	return true
+}
+
+// Uint64 returns the uint64 representation of z. If z cannot be represented in a uint64, the result is undefined.
+func (z *Element) Uint64() uint64 {
+	return uint64(z.Bits()[0])
+}
+
+// FitsOnOneWord reports whether z words (except the least significant word) are 0
+//
+// It is the responsibility of the caller to convert from Montgomery to Regular form if needed.
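+//
+// For a single-limb field this is vacuously true; an illustrative sketch:
+//
+//	var z Element
+//	z.SetUint64(5)
+//	_ = z.FitsOnOneWord() // always true for babybear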
+func (z *Element) FitsOnOneWord() bool {
+	return true
+}
+
+// Cmp compares (lexicographic order) z and x and returns:
+//
+//	-1 if z < x
+//	 0 if z == x
+//	+1 if z > x
+func (z *Element) Cmp(x *Element) int {
+	_z := z.Bits()
+	_x := x.Bits()
+	if _z[0] > _x[0] {
+		return 1
+	} else if _z[0] < _x[0] {
+		return -1
+	}
+	return 0
+}
+
+// LexicographicallyLargest returns true if this element is strictly lexicographically
+// larger than its negation, false otherwise
+func (z *Element) LexicographicallyLargest() bool {
+	// adapted from github.com/zkcrypto/bls12_381
+	// we check if the element is larger than (q-1) / 2
+	// if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2
+
+	_z := z.Bits()
+
+	var b uint32
+	_, b = bits.Sub32(_z[0], 1006632961, 0)
+
+	return b == 0
+}
+
+// SetRandom sets z to a uniform random value in [0, q).
+//
+// This might error only if reading from crypto/rand.Reader errors,
+// in which case, value of z is undefined.
+func (z *Element) SetRandom() (*Element, error) {
+	// this code is generated for all modulus
+	// and derived from go/src/crypto/rand/util.go
+
+	// l is number of limbs * 8; the number of bytes needed to reconstruct 1 uint64
+	const l = 8
+
+	// bitLen is the maximum bit length needed to encode a value < q.
+	const bitLen = 31
+
+	// k is the maximum byte length needed to encode a value < q.
+	const k = (bitLen + 7) / 8
+
+	// b is the number of bits in the most significant byte of q-1.
+	b := uint(bitLen % 8)
+	if b == 0 {
+		b = 8
+	}
+
+	var bytes [l]byte
+
+	for {
+		// note that bytes[k:l] is always 0
+		if _, err := io.ReadFull(rand.Reader, bytes[:k]); err != nil {
+			return nil, err
+		}
+
+		// Clear unused bits in the most significant byte to increase probability
+		// that the candidate is < q.
+		bytes[k-1] &= uint8(int(1<<b) - 1)
+		z[0] = binary.LittleEndian.Uint32(bytes[0:4])
+
+		if !z.smallerThanModulus() {
+			continue // ignore the candidate and re-sample
+		}
+
+		return z, nil
+	}
+}
+
+// smallerThanModulus returns true if z < q
+// This is not constant time
+func (z *Element) smallerThanModulus() bool {
+	return z[0] < q
+}
+
+// One returns 1
+func One() Element {
+	var one Element
+	one.SetOne()
+	return one
+}
+
+// Halve sets z to z / 2 (mod q)
+func (z *Element) Halve() {
+	if z[0]&1 == 1 {
+		// z = z + q (cannot overflow since z < q < 2³¹)
+		z[0], _ = bits.Add32(z[0], q0, 0)
+	}
+	// z = z >> 1
+	z[0] >>= 1
+
+}
+
+// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation
+// sets and returns z = z * 1
+func (z *Element) fromMont() *Element {
+	fromMont(z)
+	return z
+}
+
+// Add z = x + y (mod q)
+func (z *Element) Add(x, y *Element) *Element {
+
+	z[0], _ = bits.Add32(x[0], y[0], 0)
+	if z[0] >= q {
+		z[0] -= q
+	}
+	return z
+}
+
+// Double z = x + x (mod q), aka Lsh 1
+func (z *Element) Double(x *Element) *Element {
+	z[0] = (x[0] << 1)
+	if z[0] >= q {
+		z[0] -= q
+	}
+	return z
+}
+
+// Sub z = x - y (mod q)
+func (z *Element) Sub(x, y *Element) *Element {
+	var b uint32
+	z[0], b = bits.Sub32(x[0], y[0], 0)
+	if b != 0 {
+		z[0] += q
+	}
+	return z
+}
+
+// Neg z = q - x
+func (z *Element) Neg(x *Element) *Element {
+	if x.IsZero() {
+		z.SetZero()
+		return z
+	}
+	z[0] = q - x[0]
+	return z
+}
+
+// Select is a constant-time conditional move.
+// If c=0, z = x0. Else z = x1
+func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element {
+	cC := uint32((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise
+	z[0] = x0[0] ^ cC&(x0[0]^x1[0])
+	return z
+}
+
+// _mulGeneric is unoptimized textbook CIOS
+// it is a fallback solution on x86 when ADX instruction set is not available
+// and is used for testing purposes.
+func _mulGeneric(z, x, y *Element) {
+
+	// Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKs"
+	// by Y. El Housni and G.
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t [2]uint32 + var D uint32 + var m, C uint32 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul32(y[0], x[0]) + + t[1], D = bits.Add32(t[1], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + + t[0], C = bits.Add32(t[1], C, 0) + t[1], _ = bits.Add32(0, D, C) + + if t[1] != 0 { + // we need to reduce, we have a result on 2 words + z[0], _ = bits.Sub32(t[0], q0, 0) + return + } + + // copy t into z + z[0] = t[0] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } +} + +func _fromMontGeneric(z *Element) { + // the following lines implement z = z * 1 + // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + z[0] = C + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } +} + +func _reduceGeneric(z *Element) { + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } +} + +// BatchInvert returns a new slice with every element inverted. +// Uses Montgomery batch inversion trick +func BatchInvert(a []Element) []Element { + res := make([]Element, len(a)) + if len(a) == 0 { + return res + } + + zeroes := bitset.New(uint(len(a))) + accumulator := One() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes.Set(uint(i)) + continue + } + res[i] = accumulator + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes.Test(uint(i)) { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +func _butterflyGeneric(a, b *Element) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// BitLen returns the minimum number of bits needed to represent z +// returns 0 if z == 0 +func (z *Element) BitLen() int { + return bits.Len32(z[0]) +} + +// Hash msg to count prime field elements. 
+// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2
+func Hash(msg, dst []byte, count int) ([]Element, error) {
+	// 128 bits of security
+	// L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128
+	const Bytes = 1 + (Bits-1)/8
+	const L = 16 + Bytes
+
+	lenInBytes := count * L
+	pseudoRandomBytes, err := hash.ExpandMsgXmd(msg, dst, lenInBytes)
+	if err != nil {
+		return nil, err
+	}
+
+	// get temporary big int from the pool
+	vv := pool.BigInt.Get()
+
+	res := make([]Element, count)
+	for i := 0; i < count; i++ {
+		vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L])
+		res[i].SetBigInt(vv)
+	}
+
+	// release object into pool
+	pool.BigInt.Put(vv)
+
+	return res, nil
+}
+
+// Exp z = xᵏ (mod q)
+func (z *Element) Exp(x Element, k *big.Int) *Element {
+	if k.IsUint64() && k.Uint64() == 0 {
+		return z.SetOne()
+	}
+
+	e := k
+	if k.Sign() == -1 {
+		// negative k, we invert
+		// if k < 0: xᵏ (mod q) == (x⁻¹)⁻ᵏ (mod q)
+		x.Inverse(&x)
+
+		// we negate k in a temp big.Int since
+		// Int.Bit(_) of k and -k is different
+		e = pool.BigInt.Get()
+		defer pool.BigInt.Put(e)
+		e.Neg(k)
+	}
+
+	z.Set(&x)
+
+	for i := e.BitLen() - 2; i >= 0; i-- {
+		z.Square(z)
+		if e.Bit(i) == 1 {
+			z.Mul(z, &x)
+		}
+	}
+
+	return z
+}
+
+// rSquare where r is the Montgomery constant
+// see section 2.3.2 of Tolga Acar's thesis
+// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf
+var rSquare = Element{
+	1172168163,
+}
+
+// toMont converts z to Montgomery form
+// sets and returns z = z * r²
+func (z *Element) toMont() *Element {
+	return z.Mul(z, &rSquare)
+}
+
+// String returns the decimal representation of z as generated by
+// z.Text(10).
+func (z *Element) String() string {
+	return z.Text(10)
+}
+
+// toBigInt returns z as a big.Int in Montgomery form
+func (z *Element) toBigInt(res *big.Int) *big.Int {
+	var b [Bytes]byte
+	binary.BigEndian.PutUint32(b[0:4], z[0])
+
+	return res.SetBytes(b[:])
+}
+
+// Text returns the string representation of z in the given base.
+// Base must be between 2 and 36, inclusive. The result uses the
+// lower-case letters 'a' to 'z' for digit values 10 to 35.
+// No prefix (such as "0x") is added to the string. If z is a nil
+// pointer it returns "<nil>".
+// If base == 10 and -z fits in a uint16, the prefix "-" is added to the string.
+func (z *Element) Text(base int) string {
+	if base < 2 || base > 36 {
+		panic("invalid base")
+	}
+	if z == nil {
+		return "<nil>"
+	}
+
+	const maxUint16 = 65535
+	if base == 10 {
+		var zzNeg Element
+		zzNeg.Neg(z)
+		zzNeg.fromMont()
+		if zzNeg[0] <= maxUint16 && zzNeg[0] != 0 {
+			return "-" + strconv.FormatUint(uint64(zzNeg[0]), base)
+		}
+	}
+	zz := z.Bits()
+	return strconv.FormatUint(uint64(zz[0]), base)
+}
+
+// BigInt sets and returns z as a *big.Int
+func (z *Element) BigInt(res *big.Int) *big.Int {
+	_z := *z
+	_z.fromMont()
+	return _z.toBigInt(res)
+}
+
+// ToBigIntRegular returns z as a big.Int in regular form
+//
+// Deprecated: use BigInt(*big.Int) instead
+func (z Element) ToBigIntRegular(res *big.Int) *big.Int {
+	z.fromMont()
+	return z.toBigInt(res)
+}
+
+// Bits provides access to z by returning its value as a little-endian [1]uint32 array.
+// Bits is intended to support implementation of missing low-level Element
+// functionality outside this package; it should be avoided otherwise.
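+//
+// Illustrative example: Bits undoes the internal Montgomery encoding, so
+//
+//	var z Element
+//	z.SetUint64(5)
+//	_ = z.Bits()[0] // 5, while the raw word z[0] holds 5·2³² mod q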
+func (z *Element) Bits() [1]uint32 { + _z := *z + fromMont(&_z) + return _z +} + +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) + return +} + +// Marshal returns the value of z as a big-endian byte slice +func (z *Element) Marshal() []byte { + b := z.Bytes() + return b[:] +} + +// Unmarshal is an alias for SetBytes, it sets z to the value of e. +func (z *Element) Unmarshal(e []byte) { + z.SetBytes(e) +} + +// SetBytes interprets e as the bytes of a big-endian unsigned integer, +// sets z to that value, and returns z. +func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. + // get a big int from our pool + vv := pool.BigInt.Get() + vv.SetBytes(e) + + // set big int + z.SetBigInt(vv) + + // put temporary object back in pool + pool.BigInt.Put(vv) + + return z +} + +// SetBytesCanonical interprets e as the bytes of a big-endian 4-byte integer. +// If e is not a 4-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid babybear.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + +// SetBigInt sets z to v and returns z +func (z *Element) SetBigInt(v *big.Int) *Element { + z.SetZero() + + var zero big.Int + + // fast path + c := v.Cmp(&_modulus) + if c == 0 { + // v == 0 + return z + } else if c != 1 && v.Cmp(&zero) != -1 { + // 0 < v < q + return z.setBigInt(v) + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + // copy input + modular reduction + vv.Mod(v, &_modulus) + + // set big int byte value + z.setBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + return z +} + +// setBigInt assumes 0 ⩽ v < q +func (z *Element) setBigInt(v *big.Int) *Element { + vBits := v.Bits() + // we assume v < q, so even if big.Int words are on 64bits, we can safely cast them to 32bits + for i := 0; i < len(vBits); i++ { + z[i] = uint32(vBits[i]) + } + + return z.toMont() +} + +// SetString creates a big.Int with number and calls SetBigInt on z +// +// The number prefix determines the actual base: A prefix of +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 +// and no prefix is accepted. +// +// For base 16, lower and upper case letters are considered the same: +// The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. +// +// An underscore character ”_” may appear between a base +// prefix and an adjacent digit, and between successive digits; such +// underscores do not change the value of the number. +// Incorrect placement of underscores is reported as a panic if there +// are no other errors. +// +// If the number is invalid this method leaves z unchanged and returns nil, error. 
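+//
+// For example (illustrative):
+//
+//	var z Element
+//	z.SetString("0x2a") // z = 42
+//	z.SetString("-1")   // z = q - 1 (reduced mod q)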
+func (z *Element) SetString(number string) (*Element, error) { + // get temporary big int from the pool + vv := pool.BigInt.Get() + + if _, ok := vv.SetString(number, 0); !ok { + return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) + } + + z.SetBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + + return z, nil +} + +// MarshalJSON returns json encoding of z (z.Text(10)) +// If z == nil, returns null +func (z *Element) MarshalJSON() ([]byte, error) { + if z == nil { + return []byte("null"), nil + } + const maxSafeBound = 15 // we encode it as number if it's small + s := z.Text(10) + if len(s) <= maxSafeBound { + return []byte(s), nil + } + var sbb strings.Builder + sbb.WriteByte('"') + sbb.WriteString(s) + sbb.WriteByte('"') + return []byte(sbb.String()), nil +} + +// UnmarshalJSON accepts numbers and strings as input +// See Element.SetString for valid prefixes (0x, 0b, ...) +func (z *Element) UnmarshalJSON(data []byte) error { + s := string(data) + if len(s) > Bits*3 { + return errors.New("value too large (max = Element.Bits * 3)") + } + + // we accept numbers and strings, remove leading and trailing quotes if any + if len(s) > 0 && s[0] == '"' { + s = s[1:] + } + if len(s) > 0 && s[len(s)-1] == '"' { + s = s[:len(s)-1] + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + if _, ok := vv.SetString(s, 0); !ok { + return errors.New("can't parse into a big.Int: " + s) + } + + z.SetBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + return nil +} + +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 4-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint32((*b)[0:4]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid babybear.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint32((*b)[0:4], e[0]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint32((*b)[0:4]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid babybear.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint32((*b)[0:4], e[0]) +} + +func (littleEndian) String() string { return "LittleEndian" } + +// Legendre returns the Legendre symbol of z (either +1, -1, or 0.) 
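+//
+// A minimal usage sketch:
+//
+//	var z Element
+//	z.SetUint64(4)
+//	_ = z.Legendre() // 1, since 4 = 2² is a square mod q
+//	z.SetZero()
+//	_ = z.Legendre() // 0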
+func (z *Element) Legendre() int { + var l Element + // z^((q-1)/2) + l.expByLegendreExp(*z) + + if l.IsZero() { + return 0 + } + + // if l == 1 + if l.IsOne() { + return 1 + } + return -1 +} + +// Sqrt z = √x (mod q) +// if the square root doesn't exist (x is not a square mod q) +// Sqrt leaves z unchanged and returns nil +func (z *Element) Sqrt(x *Element) *Element { + // q ≡ 1 (mod 4) + // see modSqrtTonelliShanks in math/big/int.go + // using https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf + + var y, b, t, w Element + // w = x^((s-1)/2)) + w.expBySqrtExp(*x) + + // y = x^((s+1)/2)) = w * x + y.Mul(x, &w) + + // b = xˢ = w * w * x = y * x + b.Mul(&w, &y) + + // g = nonResidue ^ s + var g = Element{ + 66106732, + } + r := uint64(27) + + // compute legendre symbol + // t = x^((q-1)/2) = r-1 squaring of xˢ + t = b + for i := uint64(0); i < r-1; i++ { + t.Square(&t) + } + if t.IsZero() { + return z.SetZero() + } + if !t.IsOne() { + // t != 1, we don't have a square root + return nil + } + for { + var m uint64 + t = b + + // for t != 1 + for !t.IsOne() { + t.Square(&t) + m++ + } + + if m == 0 { + return z.Set(&y) + } + // t = g^(2^(r-m-1)) (mod q) + ge := int(r - m - 1) + t = g + for ge > 0 { + t.Square(&t) + ge-- + } + + g.Square(&t) + y.Mul(&y, &t) + b.Mul(&b, &g) + r = m + } +} + +// Inverse z = x⁻¹ (mod q) +// +// note: allocates a big.Int (math/big) +func (z *Element) Inverse(x *Element) *Element { + var _xNonMont big.Int + x.BigInt(&_xNonMont) + _xNonMont.ModInverse(&_xNonMont, Modulus()) + z.SetBigInt(&_xNonMont) + return z +} diff --git a/field/babybear/element_exp.go b/field/babybear/element_exp.go new file mode 100644 index 0000000000..44675979e6 --- /dev/null +++ b/field/babybear/element_exp.go @@ -0,0 +1,92 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +// expBySqrtExp is equivalent to z.Exp(x, 7) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expBySqrtExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _110 = 2*_11 + // return 1 + _110 + // + // Operations: 2 squares 2 multiplies + + // Allocate Temporaries. + var () + + // var + // Step 1: z = x^0x2 + z.Square(&x) + + // Step 2: z = x^0x3 + z.Mul(&x, z) + + // Step 3: z = x^0x6 + z.Square(z) + + // Step 4: z = x^0x7 + z.Mul(&x, z) + + return z +} + +// expByLegendreExp is equivalent to z.Exp(x, 3c000000) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expByLegendreExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _1100 = _11 << 2 + // _1111 = _11 + _1100 + // return _1111 << 26 + // + // Operations: 29 squares 2 multiplies + + // Allocate Temporaries. 
+	var (
+		t0 = new(Element)
+	)
+
+	// var t0 Element
+	// Step 1: z = x^0x2
+	z.Square(&x)
+
+	// Step 2: z = x^0x3
+	z.Mul(&x, z)
+
+	// Step 4: t0 = x^0xc
+	t0.Square(z)
+	for s := 1; s < 2; s++ {
+		t0.Square(t0)
+	}
+
+	// Step 5: z = x^0xf
+	z.Mul(z, t0)
+
+	// Step 31: z = x^0x3c000000
+	for s := 0; s < 26; s++ {
+		z.Square(z)
+	}
+
+	return z
+}
diff --git a/field/babybear/element_purego.go b/field/babybear/element_purego.go
new file mode 100644
index 0000000000..a768309c08
--- /dev/null
+++ b/field/babybear/element_purego.go
@@ -0,0 +1,127 @@
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package babybear
+
+import "math/bits"
+
+// MulBy3 x *= 3 (mod q)
+func MulBy3(x *Element) {
+	var y Element
+	y.SetUint64(3)
+	x.Mul(x, &y)
+}
+
+// MulBy5 x *= 5 (mod q)
+func MulBy5(x *Element) {
+	var y Element
+	y.SetUint64(5)
+	x.Mul(x, &y)
+}
+
+// MulBy13 x *= 13 (mod q)
+func MulBy13(x *Element) {
+	var y Element
+	y.SetUint64(13)
+	x.Mul(x, &y)
+}
+
+func fromMont(z *Element) {
+	_fromMontGeneric(z)
+}
+
+func reduce(z *Element) {
+	_reduceGeneric(z)
+}
+
+// Mul z = x * y (mod q)
+//
+// x and y must be less than q
+func (z *Element) Mul(x, y *Element) *Element {
+
+	// In fact, since the modulus fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction):
+	// hi, lo := x * y
+	// m := (lo * qInvNeg) mod R
+	// (*) r := (hi * R + lo + m * q) / R
+	// reduce r if necessary
+
+	// On the emphasized line, we get r = hi + (lo + m * q) / R
+	// If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2
+	// This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise.
+	// Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0)
+	// This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R.
+
+	var r uint32
+	hi, lo := bits.Mul32(x[0], y[0])
+	if lo != 0 {
+		hi++ // x[0] * y[0] ≤ 2⁶⁴ - 2³³ + 1, meaning hi ≤ 2³² - 2 so no need to worry about overflow
+	}
+	m := lo * qInvNeg
+	hi2, _ := bits.Mul32(m, q)
+	r, carry := bits.Add32(hi2, hi, 0)
+
+	if carry != 0 || r >= q {
+		// we need to reduce
+		r -= q
+	}
+	z[0] = r
+
+	return z
+}
+
+// Square z = x * x (mod q)
+//
+// x must be less than q
+func (z *Element) Square(x *Element) *Element {
+	// see Mul for algorithm documentation
+
+	// In fact, since the modulus fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction):
+	// hi, lo := x * y
+	// m := (lo * qInvNeg) mod R
+	// (*) r := (hi * R + lo + m * q) / R
+	// reduce r if necessary
+
+	// On the emphasized line, we get r = hi + (lo + m * q) / R
+	// If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2
+	// This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise.
+	// Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0)
+	// This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R.
+
+	var r uint32
+	hi, lo := bits.Mul32(x[0], x[0])
+	if lo != 0 {
+		hi++ // x[0] * x[0] ≤ (2³² - 1)² = 2⁶⁴ - 2³³ + 1, meaning hi ≤ 2³² - 2, so the increment cannot overflow
+	}
+	m := lo * qInvNeg
+	hi2, _ := bits.Mul32(m, q)
+	r, carry := bits.Add32(hi2, hi, 0)
+
+	if carry != 0 || r >= q {
+		// we need to reduce
+		r -= q
+	}
+	z[0] = r
+
+	return z
+}
+
+// Butterfly sets
+//
+//	a = a + b (mod q)
+//	b = a - b (mod q)
+func Butterfly(a, b *Element) {
+	_butterflyGeneric(a, b)
+}
diff --git a/field/babybear/element_test.go b/field/babybear/element_test.go
new file mode 100644
index 0000000000..14034e42a5
--- /dev/null
+++ b/field/babybear/element_test.go
@@ -0,0 +1,2256 @@
+// Copyright 2020 ConsenSys Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package babybear
+
+import (
+	"crypto/rand"
+	"encoding/json"
+	"fmt"
+	"math/big"
+	"math/bits"
+
+	"testing"
+
+	"github.com/leanovate/gopter"
+	ggen "github.com/leanovate/gopter/gen"
+	"github.com/leanovate/gopter/prop"
+
+	"github.com/stretchr/testify/require"
+)
+
+// -------------------------------------------------------------------------------------------------
+// benchmarks
+// most benchmarks are rudimentary and should sample a large number of random inputs
+// or be run multiple times, so that they don't simply measure the fastest path of the function
+
+var benchResElement Element
+
+func BenchmarkElementSelect(b *testing.B) {
+	var x, y Element
+	x.SetRandom()
+	y.SetRandom()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		benchResElement.Select(i%3, &x, &y)
+	}
+}
+
+func BenchmarkElementSetRandom(b *testing.B) {
+	var x Element
+	x.SetRandom()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _ = x.SetRandom()
+	}
+}
+
+func BenchmarkElementSetBytes(b *testing.B) {
+	var x Element
+	x.SetRandom()
+	bb := x.Bytes()
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		benchResElement.SetBytes(bb[:])
+	}
+
+}
+
+func BenchmarkElementMulByConstants(b *testing.B) {
+	b.Run("mulBy3", func(b *testing.B) {
+		benchResElement.SetRandom()
+		b.ResetTimer()
+		for i := 0; i < b.N; i++ {
+			MulBy3(&benchResElement)
+		}
+	})
+	b.Run("mulBy5", func(b *testing.B) {
+		benchResElement.SetRandom()
+		b.ResetTimer()
+		for i := 0; i < b.N; i++ {
+			MulBy5(&benchResElement)
+		}
+	})
+	b.Run("mulBy13", func(b *testing.B) {
+		benchResElement.SetRandom()
+		b.ResetTimer()
+		for i := 0; i < b.N; i++ {
+			MulBy13(&benchResElement)
+		}
+	})
+}
+
+func BenchmarkElementInverse(b *testing.B) {
+	var x Element
+	x.SetRandom()
+	benchResElement.SetRandom()
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		benchResElement.Inverse(&x)
+	}
+
+}
+
+func BenchmarkElementButterfly(b *testing.B) {
+	var x Element
+	x.SetRandom()
+	benchResElement.SetRandom()
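+	// reset the timer so the random setup above is not measured
+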
b.ResetTimer() + for i := 0; i < b.N; i++ { + Butterfly(&x, &benchResElement) + } +} + +func BenchmarkElementExp(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b1, _ := rand.Int(rand.Reader, Modulus()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Exp(x, b1) + } +} + +func BenchmarkElementDouble(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Double(&benchResElement) + } +} + +func BenchmarkElementAdd(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Add(&x, &benchResElement) + } +} + +func BenchmarkElementSub(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sub(&x, &benchResElement) + } +} + +func BenchmarkElementNeg(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Neg(&benchResElement) + } +} + +func BenchmarkElementDiv(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Div(&x, &benchResElement) + } +} + +func BenchmarkElementFromMont(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.fromMont() + } +} + +func BenchmarkElementSquare(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Square(&benchResElement) + } +} + +func BenchmarkElementSqrt(b *testing.B) { + var a Element + a.SetUint64(4) + a.Neg(&a) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sqrt(&a) + } +} + +func BenchmarkElementMul(b *testing.B) { + x := Element{ + 1172168163, + } + benchResElement.SetOne() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Mul(&benchResElement, &x) + } +} + +func BenchmarkElementCmp(b *testing.B) { + x := Element{ + 1172168163, + } + benchResElement = x + benchResElement[0] = 0 + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Cmp(&x) + } +} + +func TestElementCmp(t *testing.T) { + var x, y Element + + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + one := One() + y.Sub(&y, &one) + + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } + + x = y + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + x.Sub(&x, &one) + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } +} + +func TestElementNegZero(t *testing.T) { + var a, b Element + b.SetZero() + for a.IsZero() { + a.SetRandom() + } + a.Neg(&b) + if !a.IsZero() { + t.Fatal("neg(0) != 0") + } +} + +// ------------------------------------------------------------------------------------------------- +// Gopter tests +// most of them are generated with a template + +const ( + nbFuzzShort = 200 + nbFuzz = 1000 +) + +// special values to be used in tests +var staticTestValues []Element + +func init() { + staticTestValues = append(staticTestValues, Element{}) // zero + staticTestValues = append(staticTestValues, One()) // one + staticTestValues = append(staticTestValues, rSquare) // r² + var e, one Element + one.SetOne() + e.Sub(&qElement, &one) + staticTestValues = append(staticTestValues, e) // q - 1 + e.Double(&one) + staticTestValues = append(staticTestValues, e) // 2 + + { + a := qElement + a[0]-- + staticTestValues = append(staticTestValues, a) + } + staticTestValues = 
append(staticTestValues, Element{0})
+	staticTestValues = append(staticTestValues, Element{1})
+	staticTestValues = append(staticTestValues, Element{2})
+
+	{
+		a := qElement
+		a[0]--
+		staticTestValues = append(staticTestValues, a)
+	}
+
+	{
+		a := qElement
+		a[0] = 0
+		staticTestValues = append(staticTestValues, a)
+	}
+
+}
+
+func TestElementReduce(t *testing.T) {
+	testValues := make([]Element, len(staticTestValues))
+	copy(testValues, staticTestValues)
+
+	for i := range testValues {
+		s := testValues[i]
+		expected := s
+		reduce(&s)
+		_reduceGeneric(&expected)
+		if !s.Equal(&expected) {
+			t.Fatal("reduce failed: asm and generic impl don't match")
+		}
+	}
+
+	t.Parallel()
+	parameters := gopter.DefaultTestParameters()
+	if testing.Short() {
+		parameters.MinSuccessfulTests = nbFuzzShort
+	} else {
+		parameters.MinSuccessfulTests = nbFuzz
+	}
+
+	properties := gopter.NewProperties(parameters)
+
+	genA := genFull()
+
+	properties.Property("reduce should output a result smaller than modulus", prop.ForAll(
+		func(a Element) bool {
+			b := a
+			reduce(&a)
+			_reduceGeneric(&b)
+			return a.smallerThanModulus() && a.Equal(&b)
+		},
+		genA,
+	))
+
+	properties.TestingRun(t, gopter.ConsoleReporter(false))
+
+}
+
+func TestElementEqual(t *testing.T) {
+	t.Parallel()
+	parameters := gopter.DefaultTestParameters()
+	if testing.Short() {
+		parameters.MinSuccessfulTests = nbFuzzShort
+	} else {
+		parameters.MinSuccessfulTests = nbFuzz
+	}
+
+	properties := gopter.NewProperties(parameters)
+
+	genA := gen()
+	genB := gen()
+
+	properties.Property("x.Equal(&y) iff x == y; likely false for random pairs", prop.ForAll(
+		func(a testPairElement, b testPairElement) bool {
+			return a.element.Equal(&b.element) == (a.element == b.element)
+		},
+		genA,
+		genB,
+	))
+
+	properties.Property("x.Equal(&y) if x == y", prop.ForAll(
+		func(a testPairElement) bool {
+			b := a.element
+			return a.element.Equal(&b)
+		},
+		genA,
+	))
+
+	properties.TestingRun(t, gopter.ConsoleReporter(false))
+}
+
+func TestElementBytes(t *testing.T) {
+	t.Parallel()
+	parameters := gopter.DefaultTestParameters()
+	if testing.Short() {
+		parameters.MinSuccessfulTests = nbFuzzShort
+	} else {
+		parameters.MinSuccessfulTests = nbFuzz
+	}
+
+	properties := gopter.NewProperties(parameters)
+
+	genA := gen()
+
+	properties.Property("SetBytes(Bytes()) should stay constant", prop.ForAll(
+		func(a testPairElement) bool {
+			var b Element
+			bytes := a.element.Bytes()
+			b.SetBytes(bytes[:])
+			return a.element.Equal(&b)
+		},
+		genA,
+	))
+
+	properties.TestingRun(t, gopter.ConsoleReporter(false))
+}
+
+func TestElementInverseExp(t *testing.T) {
+	// the inverse must agree with exponentiation by q-2 (Fermat's little theorem)
+	exp := Modulus()
+	exp.Sub(exp, new(big.Int).SetUint64(2))
+
+	invMatchExp := func(a testPairElement) bool {
+		var b Element
+		b.Set(&a.element)
+		a.element.Inverse(&a.element)
+		b.Exp(b, exp)
+
+		return a.element.Equal(&b)
+	}
+
+	t.Parallel()
+	parameters := gopter.DefaultTestParameters()
+	if testing.Short() {
+		parameters.MinSuccessfulTests = nbFuzzShort
+	} else {
+		parameters.MinSuccessfulTests = nbFuzz
+	}
+	properties := gopter.NewProperties(parameters)
+	genA := gen()
+	properties.Property("inv == x^(q-2)", prop.ForAll(invMatchExp, genA))
+	properties.TestingRun(t, gopter.ConsoleReporter(false))
+
+	parameters.MinSuccessfulTests = 1
+	properties = gopter.NewProperties(parameters)
+	properties.Property("inv(0) == 0", prop.ForAll(invMatchExp, ggen.OneConstOf(testPairElement{})))
+	properties.TestingRun(t, gopter.ConsoleReporter(false))
+
+}
+
+func mulByConstant(z *Element, c uint8) {
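+	// reference path for the MulByN tests: multiply by the constant via a full field multiplication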
+	var y Element
+	y.SetUint64(uint64(c))
+	z.Mul(z, &y)
+}
+
+func TestElementMulByConstants(t *testing.T) {
+
+	t.Parallel()
+	parameters := gopter.DefaultTestParameters()
+	if testing.Short() {
+		parameters.MinSuccessfulTests = nbFuzzShort
+	} else {
+		parameters.MinSuccessfulTests = nbFuzz
+	}
+
+	properties := gopter.NewProperties(parameters)
+
+	genA := gen()
+
+	implemented := []uint8{0, 1, 2, 3, 5, 13}
+	properties.Property("mulByConstant", prop.ForAll(
+		func(a testPairElement) bool {
+			for _, c := range implemented {
+				var constant Element
+				constant.SetUint64(uint64(c))
+
+				b := a.element
+				b.Mul(&b, &constant)
+
+				aa := a.element
+				mulByConstant(&aa, c)
+
+				if !aa.Equal(&b) {
+					return false
+				}
+			}
+
+			return true
+		},
+		genA,
+	))
+
+	properties.Property("MulBy3(x) == Mul(x, 3)", prop.ForAll(
+		func(a testPairElement) bool {
+			var constant Element
+			constant.SetUint64(3)
+
+			b := a.element
+			b.Mul(&b, &constant)
+
+			MulBy3(&a.element)
+
+			return a.element.Equal(&b)
+		},
+		genA,
+	))
+
+	properties.Property("MulBy5(x) == Mul(x, 5)", prop.ForAll(
+		func(a testPairElement) bool {
+			var constant Element
+			constant.SetUint64(5)
+
+			b := a.element
+			b.Mul(&b, &constant)
+
+			MulBy5(&a.element)
+
+			return a.element.Equal(&b)
+		},
+		genA,
+	))
+
+	properties.Property("MulBy13(x) == Mul(x, 13)", prop.ForAll(
+		func(a testPairElement) bool {
+			var constant Element
+			constant.SetUint64(13)
+
+			b := a.element
+			b.Mul(&b, &constant)
+
+			MulBy13(&a.element)
+
+			return a.element.Equal(&b)
+		},
+		genA,
+	))
+
+	properties.TestingRun(t, gopter.ConsoleReporter(false))
+
+}
+
+func TestElementLegendre(t *testing.T) {
+	t.Parallel()
+	parameters := gopter.DefaultTestParameters()
+	if testing.Short() {
+		parameters.MinSuccessfulTests = nbFuzzShort
+	} else {
+		parameters.MinSuccessfulTests = nbFuzz
+	}
+
+	properties := gopter.NewProperties(parameters)
+
+	genA := gen()
+
+	properties.Property("Legendre should output the same result as big.Int.Jacobi", prop.ForAll(
+		func(a testPairElement) bool {
+			return a.element.Legendre() == big.Jacobi(&a.bigint, Modulus())
+		},
+		genA,
+	))
+
+	properties.TestingRun(t, gopter.ConsoleReporter(false))
+
+}
+
+func TestElementBitLen(t *testing.T) {
+	t.Parallel()
+	parameters := gopter.DefaultTestParameters()
+	if testing.Short() {
+		parameters.MinSuccessfulTests = nbFuzzShort
+	} else {
+		parameters.MinSuccessfulTests = nbFuzz
+	}
+
+	properties := gopter.NewProperties(parameters)
+
+	genA := gen()
+
+	properties.Property("BitLen should output the same result as big.Int.BitLen", prop.ForAll(
+		func(a testPairElement) bool {
+			return a.element.fromMont().BitLen() == a.bigint.BitLen()
+		},
+		genA,
+	))
+
+	properties.TestingRun(t, gopter.ConsoleReporter(false))
+}
+
+func TestElementButterflies(t *testing.T) {
+
+	t.Parallel()
+	parameters := gopter.DefaultTestParameters()
+	if testing.Short() {
+		parameters.MinSuccessfulTests = nbFuzzShort
+	} else {
+		parameters.MinSuccessfulTests = nbFuzz
+	}
+
+	properties := gopter.NewProperties(parameters)
+
+	genA := gen()
+
+	properties.Property("Butterfly must match the generic implementation (a+b, a-b)", prop.ForAll(
+		func(a, b testPairElement) bool {
+			a0, b0 := a.element, b.element
+
+			_butterflyGeneric(&a.element, &b.element)
+			Butterfly(&a0, &b0)
+
+			return a.element.Equal(&a0) && b.element.Equal(&b0)
+		},
+		genA,
+		genA,
+	))
+
+	properties.TestingRun(t, gopter.ConsoleReporter(false))
+
+}
+
+func TestElementLexicographicallyLargest(t *testing.T) {
+	t.Parallel()
+	parameters := gopter.DefaultTestParameters()
+	if testing.Short() {
+ parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("element.Cmp should match LexicographicallyLargest output", prop.ForAll( + func(a testPairElement) bool { + var negA Element + negA.Neg(&a.element) + + cmpResult := a.element.Cmp(&negA) + lResult := a.element.LexicographicallyLargest() + + if lResult && cmpResult == 1 { + return true + } + if !lResult && cmpResult != 1 { + return true + } + return false + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementAdd(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Add(&a.element, &b.element) + a.element.Add(&a.element, &b.element) + b.element.Add(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Add: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Add(&a.element, &b.element) + + var d, e big.Int + d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Add(&a.element, &r) + d.Add(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Add(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Add(&a, &b) + d.Add(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Add failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSub(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Sub(&a.element, &b.element) + a.element.Sub(&a.element, &b.element) + b.element.Sub(&d, &b.element) + + return 
a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Sub(&a.element, &b.element) + + var d, e big.Int + d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Sub(&a.element, &r) + d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Sub(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Sub(&a, &b) + d.Sub(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sub failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementMul(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Mul: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Mul(&a.element, &b.element) + a.element.Mul(&a.element, &b.element) + b.element.Mul(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Mul(&a.element, &b.element) + + var d, e big.Int + d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Mul(&a.element, &r) + d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) + + // checking generic impl against asm path + var cGeneric Element + _mulGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. 
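+				// e.g. log a.element and r here so the mismatch can be reproduced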
+ return false + } + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Mul(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Mul(&a.element, &b.element) + _mulGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Mul(&a, &b) + d.Mul(&aBig, &bBig).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _mulGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Mul failed special test values: asm and generic impl don't match") + } + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Mul failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDiv(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Div: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Div(&a.element, &b.element) + a.element.Div(&a.element, &b.element) + b.element.Div(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Div: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Div(&a.element, &b.element) + + var d, e big.Int + d.ModInverse(&b.bigint, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Div(&a.element, &r) + d.ModInverse(&rb, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Div: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Div(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + 
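// mirror the division on big.Int (ModInverse, then Mul) as the reference
+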
b.BigInt(&bBig) + + var c Element + c.Div(&a, &b) + d.ModInverse(&bBig, Modulus()) + d.Mul(&d, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Div failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementExp(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Exp: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Exp(a.element, &b.bigint) + a.element.Exp(a.element, &b.bigint) + b.element.Exp(d, &b.bigint) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Exp(a.element, &b.bigint) + + var d, e big.Int + d.Exp(&a.bigint, &b.bigint, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + r := testValues[i] + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Exp(a.element, &rb) + d.Exp(&a.bigint, &rb, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Exp(a.element, &b.bigint) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Exp(a, &bBig) + d.Exp(&aBig, &bBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Exp failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSquare(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Square: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Square(&a.element) + a.element.Square(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Square: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + + var d, e big.Int + d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Square: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c 
Element + c.Square(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Square(&a) + + var d, e big.Int + d.Mul(&aBig, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Square failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementInverse(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Inverse: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Inverse(&a.element) + a.element.Inverse(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Inverse: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + + var d, e big.Int + d.ModInverse(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Inverse: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Inverse(&a) + + var d, e big.Int + d.ModInverse(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Inverse failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSqrt(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Sqrt: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + b := a.element + + b.Sqrt(&a.element) + a.element.Sqrt(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Sqrt: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + + var d, e big.Int + d.ModSqrt(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Sqrt: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Sqrt(&a) + + var d, e big.Int + d.ModSqrt(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + 
t.Fatal("Sqrt failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDouble(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Double(&a.element) + a.element.Double(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Double: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + + var d, e big.Int + d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Double(&a) + + var d, e big.Int + d.Lsh(&aBig, 1).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Double failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementNeg(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Neg(&a.element) + a.element.Neg(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Neg: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + + var d, e big.Int + d.Neg(&a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Neg(&a) + + var d, e big.Int + d.Neg(&aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Neg failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementFixedExp(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + var ( + 
_bLegendreExponentElement *big.Int + _bSqrtExponentElement *big.Int + ) + + _bLegendreExponentElement, _ = new(big.Int).SetString("3c000000", 16) + const sqrtExponentElement = "7" + _bSqrtExponentElement, _ = new(big.Int).SetString(sqrtExponentElement, 16) + + genA := gen() + + properties.Property(fmt.Sprintf("expBySqrtExp must match Exp(%s)", sqrtExponentElement), prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expBySqrtExp(c) + d.Exp(d, _bSqrtExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("expByLegendreExp must match Exp(3c000000)", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expByLegendreExp(c) + d.Exp(d, _bLegendreExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementHalve(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + var twoInv Element + twoInv.SetUint64(2) + twoInv.Inverse(&twoInv) + + properties.Property("z.Halve must match z / 2", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.Halve() + d.Mul(&d, &twoInv) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func combineSelectionArguments(c int64, z int8) int { + if z%3 == 0 { + return 0 + } + return int(c) +} + +func TestElementSelect(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + genB := genFull() + genC := ggen.Int64() //the condition + genZ := ggen.Int8() //to make zeros artificially more likely + + properties.Property("Select: must select correctly", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c Element + c.Select(condC, &a, &b) + + if condC == 0 { + return c.Equal(&a) + } + return c.Equal(&b) + }, + genA, + genB, + genC, + genZ, + )) + + properties.Property("Select: having the receiver as operand should output the same result", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c, d Element + d.Set(&a) + c.Select(condC, &a, &b) + a.Select(condC, &a, &b) + b.Select(condC, &d, &b) + return a.Equal(&b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + genC, + genZ, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInt64(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("z.SetInt64 must match z.SetString", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInt64(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, ggen.Int64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInterface(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if 
testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genInt := ggen.Int + genInt8 := ggen.Int8 + genInt16 := ggen.Int16 + genInt32 := ggen.Int32 + genInt64 := ggen.Int64 + + genUint := ggen.UInt + genUint8 := ggen.UInt8 + genUint16 := ggen.UInt16 + genUint32 := ggen.UInt32 + genUint64 := ggen.UInt64 + + properties.Property("z.SetInterface must match z.SetString with int8", prop.ForAll( + func(a testPairElement, v int8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt8(), + )) + + properties.Property("z.SetInterface must match z.SetString with int16", prop.ForAll( + func(a testPairElement, v int16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt16(), + )) + + properties.Property("z.SetInterface must match z.SetString with int32", prop.ForAll( + func(a testPairElement, v int32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt32(), + )) + + properties.Property("z.SetInterface must match z.SetString with int64", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt64(), + )) + + properties.Property("z.SetInterface must match z.SetString with int", prop.ForAll( + func(a testPairElement, v int) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint8", prop.ForAll( + func(a testPairElement, v uint8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint8(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint16", prop.ForAll( + func(a testPairElement, v uint16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint16(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint32", prop.ForAll( + func(a testPairElement, v uint32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint32(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint64", prop.ForAll( + func(a testPairElement, v uint64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint64(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint", prop.ForAll( + func(a testPairElement, v uint) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + { + assert := require.New(t) + var e Element + r, err := e.SetInterface(nil) + assert.Nil(r) + assert.Error(err) + + var ptE *Element + var ptB *big.Int + + r, err = e.SetInterface(ptE) + assert.Nil(r) + assert.Error(err) + ptE = new(Element).SetOne() + r, err = e.SetInterface(ptE) + 
assert.NoError(err) + assert.True(r.IsOne()) + + r, err = e.SetInterface(ptB) + assert.Nil(r) + assert.Error(err) + + } +} + +func TestElementNegativeExp(t *testing.T) { + t.Parallel() + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("x⁻ᵏ == 1/xᵏ", prop.ForAll( + func(a, b testPairElement) bool { + + var nb, d, e big.Int + nb.Neg(&b.bigint) + + var c Element + c.Exp(a.element, &nb) + + d.Exp(&a.bigint, &nb, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNewElement(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + e := NewElement(1) + assert.True(e.IsOne()) + + e = NewElement(0) + assert.True(e.IsZero()) +} + +func TestElementBatchInvert(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + // ensure batchInvert([x]) == invert(x) + for i := int64(-1); i <= 2; i++ { + var e, eInv Element + e.SetInt64(i) + eInv.Inverse(&e) + + a := []Element{e} + aInv := BatchInvert(a) + + assert.True(aInv[0].Equal(&eInv), "batchInvert != invert") + + } + + // test x * x⁻¹ == 1 + tData := [][]int64{ + {-1, 1, 2, 3}, + {0, -1, 1, 2, 3, 0}, + {0, -1, 1, 0, 2, 3, 0}, + {-1, 1, 0, 2, 3}, + {0, 0, 1}, + {1, 0, 0}, + {0, 0, 0}, + } + + for _, t := range tData { + a := make([]Element, len(t)) + for i := 0; i < len(a); i++ { + a[i].SetInt64(t[i]) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + assert.True(aInv[i].IsZero(), "0⁻¹ != 0") + } else { + assert.True(a[i].Mul(&a[i], &aInv[i]).IsOne(), "x * x⁻¹ != 1") + } + } + } + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("batchInvert --> x * x⁻¹ == 1", prop.ForAll( + func(tp testPairElement, r uint8) bool { + + a := make([]Element, r) + if r != 0 { + a[0] = tp.element + + } + one := One() + for i := 1; i < len(a); i++ { + a[i].Add(&a[i-1], &one) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + if !aInv[i].IsZero() { + return false + } + } else { + if !a[i].Mul(&a[i], &aInv[i]).IsOne() { + return false + } + } + } + return true + }, + genA, ggen.UInt8(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementFromMont(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.fromMont() + _fromMontGeneric(&d) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("x.fromMont().toMont() == x", prop.ForAll( + func(a testPairElement) bool { + c := a.element + c.fromMont().toMont() + return c.Equal(&a.element) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementJSON(t *testing.T) { + assert := 
require.New(t) + + type S struct { + A Element + B [3]Element + C *Element + D *Element + } + + // encode to JSON + var s S + s.A.SetString("-1") + s.B[2].SetUint64(42) + s.D = new(Element).SetUint64(8000) + + encoded, err := json.Marshal(&s) + assert.NoError(err) + // we may need to adjust "42" and "8000" values for some moduli; see Text() method for more details. + formatValue := func(v int64) string { + var a big.Int + a.SetInt64(v) + a.Mod(&a, Modulus()) + const maxUint16 = 65535 + var aNeg big.Int + aNeg.Neg(&a).Mod(&aNeg, Modulus()) + if aNeg.Uint64() != 0 && aNeg.Uint64() <= maxUint16 { + return "-" + aNeg.Text(10) + } + return a.Text(10) + } + expected := fmt.Sprintf("{\"A\":%s,\"B\":[0,0,%s],\"C\":null,\"D\":%s}", formatValue(-1), formatValue(42), formatValue(8000)) + assert.Equal(expected, string(encoded)) + + // decode valid + var decoded S + err = json.Unmarshal([]byte(expected), &decoded) + assert.NoError(err) + + assert.Equal(s, decoded, "element -> json -> element round trip failed") + + // decode hex and string values + withHexValues := "{\"A\":\"-1\",\"B\":[0,\"0x00000\",\"0x2A\"],\"C\":null,\"D\":\"8000\"}" + + var decodedS S + err = json.Unmarshal([]byte(withHexValues), &decodedS) + assert.NoError(err) + + assert.Equal(s, decodedS, " json with strings -> element failed") + +} + +type testPairElement struct { + element Element + bigint big.Int +} + +func gen() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g testPairElement + + g.element = Element{ + uint32(genParams.NextUint64()), + } + if qElement[0] != ^uint32(0) { + g.element[0] %= (qElement[0] + 1) + } + + for !g.element.smallerThanModulus() { + g.element = Element{ + uint32(genParams.NextUint64()), + } + if qElement[0] != ^uint32(0) { + g.element[0] %= (qElement[0] + 1) + } + } + + g.element.BigInt(&g.bigint) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genRandomFq(genParams *gopter.GenParameters) Element { + var g Element + + g = Element{ + uint32(genParams.NextUint64()), + } + + if qElement[0] != ^uint32(0) { + g[0] %= (qElement[0] + 1) + } + + for !g.smallerThanModulus() { + g = Element{ + uint32(genParams.NextUint64()), + } + if qElement[0] != ^uint32(0) { + g[0] %= (qElement[0] + 1) + } + } + + return g +} + +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + + var carry uint32 + a[0], _ = bits.Add32(a[0], qElement[0], carry) + + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} + +func genElement() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + a := genRandomFq(genParams) + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} diff --git a/field/babybear/vector.go b/field/babybear/vector.go new file mode 100644 index 0000000000..19a57ac10d --- /dev/null +++ b/field/babybear/vector.go @@ -0,0 +1,303 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package babybear
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"runtime"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"unsafe"
+)
+
+// Vector represents a slice of Element.
+//
+// It implements the following interfaces:
+// - Stringer
+// - io.WriterTo
+// - io.ReaderFrom
+// - encoding.BinaryMarshaler
+// - encoding.BinaryUnmarshaler
+// - sort.Interface
+type Vector []Element
+
+// MarshalBinary implements encoding.BinaryMarshaler
+func (vector *Vector) MarshalBinary() (data []byte, err error) {
+	var buf bytes.Buffer
+
+	if _, err = vector.WriteTo(&buf); err != nil {
+		return
+	}
+	return buf.Bytes(), nil
+}
+
+// UnmarshalBinary implements encoding.BinaryUnmarshaler
+func (vector *Vector) UnmarshalBinary(data []byte) error {
+	r := bytes.NewReader(data)
+	_, err := vector.ReadFrom(r)
+	return err
+}
+
+// WriteTo implements io.WriterTo and writes a vector of big endian encoded Element.
+// Length of the vector is encoded as a uint32 on the first 4 bytes.
+func (vector *Vector) WriteTo(w io.Writer) (int64, error) {
+	// encode slice length
+	if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil {
+		return 0, err
+	}
+
+	n := int64(4)
+
+	var buf [Bytes]byte
+	for i := 0; i < len(*vector); i++ {
+		BigEndian.PutElement(&buf, (*vector)[i])
+		m, err := w.Write(buf[:])
+		n += int64(m)
+		if err != nil {
+			return n, err
+		}
+	}
+	return n, nil
+}
+
+// AsyncReadFrom reads a vector of big endian encoded Element.
+// Length of the vector must be encoded as a uint32 on the first 4 bytes.
+// It consumes the needed bytes from the reader and returns the number of bytes read and an error if any.
+// It also returns a channel that will be closed when the validation is done.
+// The validation consists of checking that the elements are smaller than the modulus, and
+// converting them to Montgomery form.
+func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) {
+	chErr := make(chan error, 1)
+	var buf [Bytes]byte
+	if read, err := io.ReadFull(r, buf[:4]); err != nil {
+		close(chErr)
+		return int64(read), err, chErr
+	}
+	sliceLen := binary.BigEndian.Uint32(buf[:4])
+
+	n := int64(4)
+	(*vector) = make(Vector, sliceLen)
+	if sliceLen == 0 {
+		close(chErr)
+		return n, nil, chErr
+	}
+
+	bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes)
+	read, err := io.ReadFull(r, bSlice)
+	n += int64(read)
+	if err != nil {
+		close(chErr)
+		return n, err, chErr
+	}
+
+	go func() {
+		var cptErrors uint64
+		// process the elements in parallel
+		execute(int(sliceLen), func(start, end int) {
+
+			var z Element
+			for i := start; i < end; i++ {
+				// we have to set vector[i]
+				bstart := i * Bytes
+				bend := bstart + Bytes
+				b := bSlice[bstart:bend]
+				z[0] = binary.BigEndian.Uint32(b[0:4])
+
+				if !z.smallerThanModulus() {
+					atomic.AddUint64(&cptErrors, 1)
+					return
+				}
+				z.toMont()
+				(*vector)[i] = z
+			}
+		})
+
+		if cptErrors > 0 {
+			chErr <- fmt.Errorf("async read: %d elements failed validation", cptErrors)
+		}
+		close(chErr)
+	}()
+	return n, nil, chErr
+}
+
+// ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element.
+// Length of the vector must be encoded as a uint32 on the first 4 bytes.
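+//
+// A minimal round trip (sketch; v is any Vector):
+//
+//	var buf bytes.Buffer
+//	if _, err := v.WriteTo(&buf); err != nil {
+//		panic(err)
+//	}
+//	var w Vector
+//	if _, err := w.ReadFrom(&buf); err != nil {
+//		panic(err)
+//	}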
+func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { + + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + return int64(read), err + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + + for i := 0; i < int(sliceLen); i++ { + read, err := io.ReadFull(r, buf[:]) + n += int64(read) + if err != nil { + return n, err + } + (*vector)[i], err = BigEndian.Element(&buf) + if err != nil { + return n, err + } + } + + return n, nil +} + +// String implements fmt.Stringer interface +func (vector Vector) String() string { + var sbb strings.Builder + sbb.WriteByte('[') + for i := 0; i < len(vector); i++ { + sbb.WriteString(vector[i].String()) + if i != len(vector)-1 { + sbb.WriteByte(',') + } + } + sbb.WriteByte(']') + return sbb.String() +} + +// Len is the number of elements in the collection. +func (vector Vector) Len() int { + return len(vector) +} + +// Less reports whether the element with +// index i should sort before the element with index j. +func (vector Vector) Less(i, j int) bool { + return vector[i].Cmp(&vector[j]) == -1 +} + +// Swap swaps the elements with indexes i and j. +func (vector Vector) Swap(i, j int) { + vector[i], vector[j] = vector[j], vector[i] +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + +// TODO @gbotrel make a public package out of that. +// execute executes the work function in parallel. 
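+// work is called with disjoint half-open ranges (start, end) that together cover [0, nbIterations).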
+// this is copy paste from internal/parallel/parallel.go +// as we don't want to generate code importing internal/ +func execute(nbIterations int, work func(int, int), maxCpus ...int) { + + nbTasks := runtime.NumCPU() + if len(maxCpus) == 1 { + nbTasks = maxCpus[0] + if nbTasks < 1 { + nbTasks = 1 + } else if nbTasks > 512 { + nbTasks = 512 + } + } + + if nbTasks == 1 { + // no go routines + work(0, nbIterations) + return + } + + nbIterationsPerCpus := nbIterations / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = nbIterations + } + + var wg sync.WaitGroup + + extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + go func() { + work(_start, _end) + wg.Done() + }() + } + + wg.Wait() +} diff --git a/field/babybear/vector_purego.go b/field/babybear/vector_purego.go new file mode 100644 index 0000000000..8843280543 --- /dev/null +++ b/field/babybear/vector_purego.go @@ -0,0 +1,54 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/field/babybear/vector_test.go b/field/babybear/vector_test.go new file mode 100644 index 0000000000..5d35c3e7bb --- /dev/null +++ b/field/babybear/vector_test.go @@ -0,0 +1,365 @@ +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +import ( + "bytes" + "fmt" + "github.com/stretchr/testify/require" + "os" + "reflect" + "sort" + "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +func TestVectorSort(t *testing.T) { + assert := require.New(t) + + v := make(Vector, 3) + v[0].SetUint64(2) + v[1].SetUint64(3) + v[2].SetUint64(1) + + sort.Sort(v) + + assert.Equal("[1,2,3]", v.String()) +} + +func TestVectorRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 3) + v1[0].SetUint64(2) + v1[1].SetUint64(3) + v1[2].SetUint64(1) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2, v3 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func TestVectorEmptyRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 0) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2, v3 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func (vector *Vector) unmarshalBinaryAsync(data []byte) error { + r := bytes.NewReader(data) + _, err, chErr := vector.AsyncReadFrom(r) + if err != nil { + return err + } + return <-chErr +} + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + 
b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). 
+ Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + uint32(genParams.NextUint64()), + } + if qElement[0] != ^uint32(0) { + mixer[0] %= (qElement[0] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + uint32(genParams.NextUint64()), + } + if qElement[0] != ^uint32(0) { + mixer[0] %= (qElement[0] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). 
+ Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index 22c15e9af8..7c600941bc 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -33,7 +33,7 @@ import ( type {{.ElementName}} [{{.NbWords}}]{{$.Word.TypeLower}} const ( - Limbs = {{.NbWords}} // number of {{$.Word.TypeLower}} words needed to represent a {{.ElementName}} + Limbs = {{.NbWords}} // number of {{$.Word.BitSize}} bits words needed to represent a {{.ElementName}} Bits = {{.NbBits}} // number of bits needed to represent a {{.ElementName}} Bytes = {{.NbBytes}} // number of bytes needed to represent a {{.ElementName}} ) @@ -89,7 +89,7 @@ func New{{.ElementName}}(v uint64) {{.ElementName}} { z.Mul(&z, &rSquare) return z {{- else }} - z := {{.ElementName}}{ {{$.Word.TypeLower}}(v) } + z := {{.ElementName}}{ v } z.Mul(&z, &rSquare) return z {{- end}} @@ -103,10 +103,9 @@ func (z *{{.ElementName}}) SetUint64(v uint64) *{{.ElementName}} { return z.Mul(z, &rSquare) // z.toMont() {{- else }} // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form - *z = {{.ElementName}}{ {{$.Word.TypeLower}}(v) } + *z = {{.ElementName}}{ v } return z.Mul(z, &rSquare) // z.toMont() {{- end}} - } // SetInt64 sets z to v and returns z @@ -223,8 +222,8 @@ func (z *{{.ElementName}}) Equal(x *{{.ElementName}}) bool { } // NotEqual returns 0 if and only if z == x; constant-time -func (z *{{.ElementName}}) NotEqual(x *{{.ElementName}}) uint64 { -return uint64({{- range $i := reverse .NbWordsIndexesNoZero}}(z[{{$i}}] ^ x[{{$i}}]) | {{end}}(z[0] ^ x[0])) +func (z *{{.ElementName}}) NotEqual(x *{{.ElementName}}) {{.Word.TypeLower}} { +return {{- range $i := reverse .NbWordsIndexesNoZero}}(z[{{$i}}] ^ x[{{$i}}]) | {{end}}(z[0] ^ x[0]) } // IsZero returns z == 0 @@ -254,7 +253,11 @@ func (z *{{.ElementName}}) IsUint64() bool { // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *{{.ElementName}}) Uint64() uint64 { - return uint64(z.Bits()[0]) + {{- if eq .Word.BitSize 32}} + return uint64(z.Bits()[0]) + {{- else}} + return z.Bits()[0] + {{- end}} } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -336,9 +339,6 @@ func (z *{{.ElementName}}) SetRandom() (*{{.ElementName}}, error) { return nil, err } - {{- if eq $.Word.BitSize 32}} - // TODO @gbotrel check this is correct with 32bits words - {{- end}} // Clear unused bits in in the most significant byte to increase probability // that the candidate is < q. 
bytes[k-1] &= uint8(int(1<= q { - z[0] -= q - } + {{- if lt .NbBits 32}} + z[0] = (x[0] << 1) + if z[0] >= q { + z[0] -= q + } + {{- else}} + if x[0]&(1<<63) == (1 << 63) { + // if highest bit is set, then we have a carry to x + x, we shift and subtract q + z[0] = (x[0] << 1) - q + } else { + // highest bit is not set, but x + x can still be >= q + z[0] = (x[0] << 1) + if z[0] >= q { + z[0] -= q + } + } + {{- end}} + return z {{- else}} {{ $hasCarry := or (not $.NoCarry) (gt $.NbWords 1)}} {{- if $hasCarry}} @@ -480,8 +494,8 @@ func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} { {{- end}} {{ template "reduce" .}} - {{- end}} return z + {{- end}} } diff --git a/field/generator/internal/templates/element/conv.go b/field/generator/internal/templates/element/conv.go index 0557c0d3f1..17298d4bb0 100644 --- a/field/generator/internal/templates/element/conv.go +++ b/field/generator/internal/templates/element/conv.go @@ -205,19 +205,26 @@ func (z *{{.ElementName}}) SetBigInt(v *big.Int) *{{.ElementName}} { func (z *{{.ElementName}}) setBigInt(v *big.Int) *{{.ElementName}} { vBits := v.Bits() - if bits.UintSize == 64 { - for i := 0; i < len(vBits); i++ { - z[i] = {{$.Word.TypeLower}}(vBits[i]) - } - } else { - for i := 0; i < len(vBits); i++ { - if i%2 == 0 { - z[i/2] = {{$.Word.TypeLower}}(vBits[i]) - } else { - z[i/2] |= {{$.Word.TypeLower}}(vBits[i]) << 32 + {{- if eq .Word.BitSize 32}} + // we assume v < q, so even if big.Int words are on 64bits, we can safely cast them to 32bits + for i := 0; i < len(vBits); i++ { + z[i] = {{$.Word.TypeLower}}(vBits[i]) + } + {{- else}} + if bits.UintSize == 64 { + for i := 0; i < len(vBits); i++ { + z[i] = {{$.Word.TypeLower}}(vBits[i]) + } + } else { + for i := 0; i < len(vBits); i++ { + if i%2 == 0 { + z[i/2] = {{$.Word.TypeLower}}(vBits[i]) + } else { + z[i/2] |= {{$.Word.TypeLower}}(vBits[i]) << 32 + } } } - } + {{- end}} return z.toMont() } diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index 6d2baef32b..7008294565 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -1569,6 +1569,10 @@ type testPair{{.ElementName}} struct { bigint big.Int } +{{- $gen64 := "genParams.NextUint64()"}} +{{- if eq .Word.BitSize 32}} +{{- $gen64 = "uint32(genParams.NextUint64())"}} +{{- end}} func gen() gopter.Gen { return func(genParams *gopter.GenParameters) *gopter.GenResult { @@ -1576,7 +1580,7 @@ func gen() gopter.Gen { g.element = {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} + {{$gen64}},{{end}} } if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { g.element[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) @@ -1586,7 +1590,7 @@ func gen() gopter.Gen { for !g.element.smallerThanModulus() { g.element = {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} + {{$gen64}},{{end}} } if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { g.element[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) @@ -1604,7 +1608,7 @@ func genRandomFq(genParams *gopter.GenParameters) {{.ElementName}} { g = {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} + {{$gen64}},{{end}} } if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { @@ -1614,7 +1618,7 @@ func genRandomFq(genParams 
*gopter.GenParameters) {{.ElementName}} { for !g.smallerThanModulus() { g = {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} + {{$gen64}},{{end}} } if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { g[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) diff --git a/field/generator/internal/templates/element/tests_vector.go b/field/generator/internal/templates/element/tests_vector.go index 9bdadd353f..3da4c94080 100644 --- a/field/generator/internal/templates/element/tests_vector.go +++ b/field/generator/internal/templates/element/tests_vector.go @@ -325,12 +325,17 @@ func genMaxVector(size int) gopter.Gen { } } +{{- $gen64 := "genParams.NextUint64()"}} +{{- if eq .Word.BitSize 32}} +{{- $gen64 = "uint32(genParams.NextUint64())"}} +{{- end}} + func genVector(size int) gopter.Gen { return func(genParams *gopter.GenParameters) *gopter.GenResult { g := make(Vector, size) mixer := {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} + {{$gen64}},{{end}} } if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { mixer[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) @@ -340,7 +345,7 @@ func genVector(size int) gopter.Gen { for !mixer.smallerThanModulus() { mixer = {{.ElementName}}{ {{- range $i := .NbWordsIndexesFull}} - {{$.Word.TypeLower}}(genParams.NextUint64()),{{end}} + {{$gen64}},{{end}} } if qElement[{{.NbWordsLastIndex}}] != ^{{$.Word.TypeLower}}(0) { mixer[{{.NbWordsLastIndex}}] %= (qElement[{{.NbWordsLastIndex}}] +1 ) diff --git a/field/goldilocks/element.go b/field/goldilocks/element.go index 8dd4d69919..889085d94c 100644 --- a/field/goldilocks/element.go +++ b/field/goldilocks/element.go @@ -627,11 +627,11 @@ func (z *Element) Text(base int) string { zzNeg.Neg(z) zzNeg.fromMont() if zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { - return "-" + strconv.FormatUint(zzNeg[0], base) + return "-" + strconv.FormatUint(uint64(zzNeg[0]), base) } } zz := z.Bits() - return strconv.FormatUint(zz[0], base) + return strconv.FormatUint(uint64(zz[0]), base) } // BigInt sets and return z as a *big.Int @@ -749,7 +749,6 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/field/goldilocks/internal/main.go b/field/goldilocks/internal/main.go deleted file mode 100644 index 04d5d75305..0000000000 --- a/field/goldilocks/internal/main.go +++ /dev/null @@ -1,21 +0,0 @@ -package main - -import ( - "fmt" - - "github.com/consensys/gnark-crypto/field/generator" - "github.com/consensys/gnark-crypto/field/generator/config" -) - -//go:generate go run main.go -func main() { - const modulus = "0xFFFFFFFF00000001" - goldilocks, err := config.NewFieldConfig("goldilocks", "Element", modulus, true) - if err != nil { - panic(err) - } - if err := generator.GenerateFF(goldilocks, "..", "", ""); err != nil { - panic(err) - } - fmt.Println("successfully generated goldilocks field") -} diff --git a/field/internal/addchain/3c000000 b/field/internal/addchain/3c000000 new file mode 100644 index 0000000000000000000000000000000000000000..7ff1a13c6446abe67f721ba6b266f802910a054a GIT binary patch literal 206 zcmWlP-3q|~07lOlAuUN-Ys8WyNs_dd7vP3haOFzm1^m2)r=9KQ^qsF06`w$x=jZ+U zyxc6lSiE8Is%da~RALz%tfR-0Q2fejOb}qbx*ImC)xx%I3=Oek2fKE$XAk@KF*3r~ 
x7!wm5I>eD9Oigj(1gB1M<_zb~ap3}&E^*}w*RFBn2Dff;=MFP7%*}D{o1A}xuGb1CDe*uWY2vp$7$il$L=mjJh7zBWd QSQwa?7#LZAd?q#q072&rs{jB1 literal 0 HcmV?d00001 diff --git a/field/goldilocks/internal/addchain/7fffffff b/field/internal/addchain/7fffffff similarity index 66% rename from field/goldilocks/internal/addchain/7fffffff rename to field/internal/addchain/7fffffff index b281f667616cfe8d6232c88a89b25860bf2916a7..2702c2bd601e9bdec1d5ecace2b67f582b0c5adf 100644 GIT binary patch delta 63 zcmaFK_=3?~?0+K@BO`l2QGR++VlLzVCI-g;4Gav@^~{WnO#THR4kJ*3CnF03Bcm6P OWMKHe`G3$2=?0+*8BO`l2QGR++VlLzV76!)uO$-b&{~MVZ8JYYGKs-jE3QtBB21Z6N PAj!b+fAjyAiRK#td)5-e diff --git a/field/goldilocks/internal/addchain/7fffffff80000000 b/field/internal/addchain/7fffffff80000000 similarity index 79% rename from field/goldilocks/internal/addchain/7fffffff80000000 rename to field/internal/addchain/7fffffff80000000 index c1d158b8e6f129a758c61a95ae1c26ed84b13c1e..3f6891cde4118c5c70424386a705be833d3be08c 100644 GIT binary patch delta 65 zcmeBT?qs$Q``^gK$jBa0l%HOdn9KOTiGlHd0|SF}Ju@RClYaq-!w6L1$;iUM$mj(m Q85sUC2LEr`Xc@=|06YW`U;qFB delta 66 zcmeBV?qaqS``^sO$jBa0l%HOdn9KOTg@N&Z69a?H|3+p;MkfCP5RVb4!jqAOfsxS* RNHQ?|V+{V^ve6=t5ddq@5o`be diff --git a/field/internal/main.go b/field/internal/main.go new file mode 100644 index 0000000000..b495d394e5 --- /dev/null +++ b/field/internal/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "fmt" + "path/filepath" + + "github.com/consensys/gnark-crypto/field/generator" + "github.com/consensys/gnark-crypto/field/generator/config" +) + +//go:generate go run main.go +func main() { + // generate the following fields + + type field struct { + name string + modulus string + } + + fields := []field{ + {"goldilocks", "0xFFFFFFFF00000001"}, + {"koalabear", "0x7f000001"}, // 2^31 - 2^24 + 1 ==> the cube map (x -> x^3) is an automorphism of the multiplicative group + {"babybear", "0x78000001"}, // 2^31 - 2^27 + 1 ==> 2-adicity 27 + } + + for _, f := range fields { + fc, err := config.NewFieldConfig(f.name, "Element", f.modulus, true) + if err != nil { + panic(err) + } + if err := generator.GenerateFF(fc, filepath.Join("..", f.name), "", ""); err != nil { + panic(err) + } + fmt.Println("successfully generated", f.name, "field") + } +} diff --git a/field/koalabear/element.go b/field/koalabear/element.go index 7e42842839..acb810492d 100644 --- a/field/koalabear/element.go +++ b/field/koalabear/element.go @@ -47,7 +47,7 @@ import ( type Element [1]uint32 const ( - Limbs = 1 // number of uint32 words needed to represent a Element + Limbs = 1 // number of 32 bits words needed to represent a Element Bits = 31 // number of bits needed to represent a Element Bytes = 4 // number of bytes needed to represent a Element ) @@ -97,7 +97,6 @@ func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{uint32(v % uint64(q0))} return z.Mul(z, &rSquare) // z.toMont() - } // SetInt64 sets z to v and returns z @@ -208,8 +207,8 @@ func (z *Element) Equal(x *Element) bool { } // NotEqual returns 0 if and only if z == x; constant-time -func (z *Element) NotEqual(x *Element) uint64 { - return uint64((z[0] ^ x[0])) +func (z *Element) NotEqual(x *Element) uint32 { + return (z[0] ^ x[0]) } // IsZero returns z == 0 @@ -300,7 +299,7 @@ func (z *Element) SetRandom() (*Element, error) { if _, err := io.ReadFull(rand.Reader, bytes[:k]); err != nil { return nil, err } - // TODO @gbotrel check this is correct with 32bits words + // Clear unused bits in in the most 
significant byte to increase probability // that the candidate is < q. bytes[k-1] &= uint8(int(1<= q { z[0] -= q } @@ -736,19 +735,9 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - - if bits.UintSize == 64 { - for i := 0; i < len(vBits); i++ { - z[i] = uint32(vBits[i]) - } - } else { - for i := 0; i < len(vBits); i++ { - if i%2 == 0 { - z[i/2] = uint32(vBits[i]) - } else { - z[i/2] |= uint32(vBits[i]) << 32 - } - } + // we assume v < q, so even if big.Int words are on 64bits, we can safely cast them to 32bits + for i := 0; i < len(vBits); i++ { + z[i] = uint32(vBits[i]) } return z.toMont() diff --git a/field/koalabear/element_exp.go b/field/koalabear/element_exp.go new file mode 100644 index 0000000000..ea8c27d6ca --- /dev/null +++ b/field/koalabear/element_exp.go @@ -0,0 +1,122 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package koalabear + +// expBySqrtExp is equivalent to z.Exp(x, 3f) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expBySqrtExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _1100 = _11 << 2 + // _1111 = _11 + _1100 + // _111100 = _1111 << 2 + // return _11 + _111100 + // + // Operations: 5 squares 3 multiplies + + // Allocate Temporaries. + var ( + t0 = new(Element) + ) + + // var t0 Element + // Step 1: z = x^0x2 + z.Square(&x) + + // Step 2: z = x^0x3 + z.Mul(&x, z) + + // Step 4: t0 = x^0xc + t0.Square(z) + for s := 1; s < 2; s++ { + t0.Square(t0) + } + + // Step 5: t0 = x^0xf + t0.Mul(z, t0) + + // Step 7: t0 = x^0x3c + for s := 0; s < 2; s++ { + t0.Square(t0) + } + + // Step 8: z = x^0x3f + z.Mul(z, t0) + + return z +} + +// expByLegendreExp is equivalent to z.Exp(x, 3f800000) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expByLegendreExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _110 = 2*_11 + // _111 = 1 + _110 + // _1110 = 2*_111 + // _1111 = 1 + _1110 + // _1111000 = _1111 << 3 + // _1111111 = _111 + _1111000 + // return _1111111 << 23 + // + // Operations: 29 squares 4 multiplies + + // Allocate Temporaries. 
+ var ( + t0 = new(Element) + ) + + // var t0 Element + // Step 1: z = x^0x2 + z.Square(&x) + + // Step 2: z = x^0x3 + z.Mul(&x, z) + + // Step 3: z = x^0x6 + z.Square(z) + + // Step 4: z = x^0x7 + z.Mul(&x, z) + + // Step 5: t0 = x^0xe + t0.Square(z) + + // Step 6: t0 = x^0xf + t0.Mul(&x, t0) + + // Step 9: t0 = x^0x78 + for s := 0; s < 3; s++ { + t0.Square(t0) + } + + // Step 10: z = x^0x7f + z.Mul(z, t0) + + // Step 33: z = x^0x3f800000 + for s := 0; s < 23; s++ { + z.Square(z) + } + + return z +} diff --git a/field/koalabear/internal/main.go b/field/koalabear/internal/main.go deleted file mode 100644 index d59fc4e46f..0000000000 --- a/field/koalabear/internal/main.go +++ /dev/null @@ -1,21 +0,0 @@ -package main - -import ( - "fmt" - - "github.com/consensys/gnark-crypto/field/generator" - "github.com/consensys/gnark-crypto/field/generator/config" -) - -//go:generate go run main.go -func main() { - const modulus = "0x7f000001" // KoalaBear 2^31 - 2^24 + 1 - koalabear, err := config.NewFieldConfig("koalabear", "Element", modulus, true) - if err != nil { - panic(err) - } - if err := generator.GenerateFF(koalabear, "..", "", ""); err != nil { - panic(err) - } - fmt.Println("successfully generated koalabear field") -} diff --git a/field/koalabear/vector.go b/field/koalabear/vector.go new file mode 100644 index 0000000000..60fa5f6183 --- /dev/null +++ b/field/koalabear/vector.go @@ -0,0 +1,303 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package koalabear + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "runtime" + "strings" + "sync" + "sync/atomic" + "unsafe" +) + +// Vector represents a slice of Element. +// +// It implements the following interfaces: +// - Stringer +// - io.WriterTo +// - io.ReaderFrom +// - encoding.BinaryMarshaler +// - encoding.BinaryUnmarshaler +// - sort.Interface +type Vector []Element + +// MarshalBinary implements encoding.BinaryMarshaler +func (vector *Vector) MarshalBinary() (data []byte, err error) { + var buf bytes.Buffer + + if _, err = vector.WriteTo(&buf); err != nil { + return + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (vector *Vector) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + _, err := vector.ReadFrom(r) + return err +} + +// WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. +// Length of the vector is encoded as a uint32 on the first 4 bytes. 
+func (vector *Vector) WriteTo(w io.Writer) (int64, error) { + // encode slice length + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { + return 0, err + } + + n := int64(4) + + var buf [Bytes]byte + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) + m, err := w.Write(buf[:]) + n += int64(m) + if err != nil { + return n, err + } + } + return n, nil +} + +// AsyncReadFrom reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +// It consumes the needed bytes from the reader and returns the number of bytes read and an error if any. +// It also returns a channel that will be closed when the validation is done. +// The validation consists of checking that the elements are smaller than the modulus, and +// converting them to Montgomery form. +func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { + chErr := make(chan error, 1) + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + close(chErr) + return int64(read), err, chErr + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + if sliceLen == 0 { + close(chErr) + return n, nil, chErr + } + + bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes) + read, err := io.ReadFull(r, bSlice) + n += int64(read) + if err != nil { + close(chErr) + return n, err, chErr + } + + go func() { + var cptErrors uint64 + // process the elements in parallel + execute(int(sliceLen), func(start, end int) { + + var z Element + for i := start; i < end; i++ { + // we have to set vector[i] + bstart := i * Bytes + bend := bstart + Bytes + b := bSlice[bstart:bend] + z[0] = binary.BigEndian.Uint32(b[0:4]) + + if !z.smallerThanModulus() { + atomic.AddUint64(&cptErrors, 1) + return + } + z.toMont() + (*vector)[i] = z + } + }) + + if cptErrors > 0 { + chErr <- fmt.Errorf("async read: %d elements failed validation", cptErrors) + } + close(chErr) + }() + return n, nil, chErr +} + +// ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { + + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + return int64(read), err + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + + for i := 0; i < int(sliceLen); i++ { + read, err := io.ReadFull(r, buf[:]) + n += int64(read) + if err != nil { + return n, err + } + (*vector)[i], err = BigEndian.Element(&buf) + if err != nil { + return n, err + } + } + + return n, nil +} + +// String implements fmt.Stringer interface +func (vector Vector) String() string { + var sbb strings.Builder + sbb.WriteByte('[') + for i := 0; i < len(vector); i++ { + sbb.WriteString(vector[i].String()) + if i != len(vector)-1 { + sbb.WriteByte(',') + } + } + sbb.WriteByte(']') + return sbb.String() +} + +// Len is the number of elements in the collection. +func (vector Vector) Len() int { + return len(vector) +} + +// Less reports whether the element with +// index i should sort before the element with index j. +func (vector Vector) Less(i, j int) bool { + return vector[i].Cmp(&vector[j]) == -1 +} + +// Swap swaps the elements with indexes i and j.
+func (vector Vector) Swap(i, j int) { + vector[i], vector[j] = vector[j], vector[i] +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + +// TODO @gbotrel make a public package out of that. +// execute executes the work function in parallel. +// this is copy paste from internal/parallel/parallel.go +// as we don't want to generate code importing internal/ +func execute(nbIterations int, work func(int, int), maxCpus ...int) { + + nbTasks := runtime.NumCPU() + if len(maxCpus) == 1 { + nbTasks = maxCpus[0] + if nbTasks < 1 { + nbTasks = 1 + } else if nbTasks > 512 { + nbTasks = 512 + } + } + + if nbTasks == 1 { + // no go routines + work(0, nbIterations) + return + } + + nbIterationsPerCpus := nbIterations / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = nbIterations + } + + var wg sync.WaitGroup + + extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + go func() { + work(_start, _end) + wg.Done() + }() + } + + wg.Wait() +} diff --git a/field/koalabear/vector_purego.go b/field/koalabear/vector_purego.go new file mode 100644 index 0000000000..71dc2cc0c3 --- /dev/null +++ b/field/koalabear/vector_purego.go @@ -0,0 +1,54 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package koalabear + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/field/koalabear/vector_test.go b/field/koalabear/vector_test.go new file mode 100644 index 0000000000..b91728a8ac --- /dev/null +++ b/field/koalabear/vector_test.go @@ -0,0 +1,365 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package koalabear + +import ( + "bytes" + "fmt" + "github.com/stretchr/testify/require" + "os" + "reflect" + "sort" + "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +func TestVectorSort(t *testing.T) { + assert := require.New(t) + + v := make(Vector, 3) + v[0].SetUint64(2) + v[1].SetUint64(3) + v[2].SetUint64(1) + + sort.Sort(v) + + assert.Equal("[1,2,3]", v.String()) +} + +func TestVectorRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 3) + v1[0].SetUint64(2) + v1[1].SetUint64(3) + v1[2].SetUint64(1) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2, v3 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func TestVectorEmptyRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 0) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2, v3 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func (vector *Vector) unmarshalBinaryAsync(data []byte) error { + r := bytes.NewReader(data) + _, err, chErr := vector.AsyncReadFrom(r) + if err != nil { + return err + } + return <-chErr +} + +func TestVectorOps(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = 2 + } else { + parameters.MinSuccessfulTests = 10 + } + properties := gopter.NewProperties(parameters) + + addVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Add(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Add(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + subVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + c.Sub(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Sub(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + scalarMulVector := func(a Vector, b Element) bool { + c := make(Vector, len(a)) + c.ScalarMul(a, &b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sumVector := func(a Vector) bool { + var sum Element + computed := a.Sum() + for i := 0; i < len(a); i++ { + sum.Add(&sum, &a[i]) + } + + return sum.Equal(&computed) + } + + innerProductVector := func(a, b Vector) bool { + computed := a.InnerProduct(b) + var innerProduct Element + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + innerProduct.Add(&innerProduct, &tmp) + } + + return innerProduct.Equal(&computed) + } + + mulVector := func(a, b Vector) bool { + c := make(Vector, len(a)) + a[0].SetUint64(0x24) + b[0].SetUint64(0x42) + c.Mul(a, b) + + for i := 0; i < len(a); i++ { + var tmp Element + tmp.Mul(&a[i], &b[i]) + if !tmp.Equal(&c[i]) { + return false + } + } + return true + } + + sizes := []int{1, 2, 3, 4, 8, 9, 15, 16, 509, 510, 511, 512, 513, 514} + type genPair struct { + g1, g2 gopter.Gen + label string + } + + for _, size := range sizes { + generators := []genPair{ + {genZeroVector(size), genZeroVector(size), "zero vectors"}, + {genMaxVector(size), genMaxVector(size), "max vectors"}, + {genVector(size), genVector(size), "random vectors"}, + {genVector(size), 
genZeroVector(size), "random and zero vectors"}, + } + for _, gp := range generators { + properties.Property(fmt.Sprintf("vector addition %d - %s", size, gp.label), prop.ForAll( + addVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector subtraction %d - %s", size, gp.label), prop.ForAll( + subVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector scalar multiplication %d - %s", size, gp.label), prop.ForAll( + scalarMulVector, + gp.g1, + genElement(), + )) + + properties.Property(fmt.Sprintf("vector sum %d - %s", size, gp.label), prop.ForAll( + sumVector, + gp.g1, + )) + + properties.Property(fmt.Sprintf("vector inner product %d - %s", size, gp.label), prop.ForAll( + innerProductVector, + gp.g1, + gp.g2, + )) + + properties.Property(fmt.Sprintf("vector multiplication %d - %s", size, gp.label), prop.ForAll( + mulVector, + gp.g1, + gp.g2, + )) + } + } + + properties.TestingRun(t, gopter.NewFormatedReporter(false, 260, os.Stdout)) +} + +func BenchmarkVectorOps(b *testing.B) { + // note; to benchmark against "no asm" version, use the following + // build tag: -tags purego + const N = 1 << 24 + a1 := make(Vector, N) + b1 := make(Vector, N) + c1 := make(Vector, N) + var mixer Element + mixer.SetRandom() + for i := 1; i < N; i++ { + a1[i-1].SetUint64(uint64(i)). + Mul(&a1[i-1], &mixer) + b1[i-1].SetUint64(^uint64(i)). + Mul(&b1[i-1], &mixer) + } + + for n := 1 << 4; n <= N; n <<= 1 { + b.Run(fmt.Sprintf("add %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Add(_a, _b) + } + }) + + b.Run(fmt.Sprintf("sub %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Sub(_a, _b) + } + }) + + b.Run(fmt.Sprintf("scalarMul %d", n), func(b *testing.B) { + _a := a1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.ScalarMul(_a, &mixer) + } + }) + + b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { + _a := a1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.Sum() + } + }) + + b.Run(fmt.Sprintf("innerProduct %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = _a.InnerProduct(_b) + } + }) + + b.Run(fmt.Sprintf("mul %d", n), func(b *testing.B) { + _a := a1[:n] + _b := b1[:n] + _c := c1[:n] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _c.Mul(_a, _b) + } + }) + } +} + +func genZeroVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genMaxVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + + qMinusOne := qElement + qMinusOne[0]-- + + for i := 0; i < size; i++ { + g[i] = qMinusOne + } + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genVector(size int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + g := make(Vector, size) + mixer := Element{ + uint32(genParams.NextUint64()), + } + if qElement[0] != ^uint32(0) { + mixer[0] %= (qElement[0] + 1) + } + + for !mixer.smallerThanModulus() { + mixer = Element{ + uint32(genParams.NextUint64()), + } + if qElement[0] != ^uint32(0) { + mixer[0] %= (qElement[0] + 1) + } + } + + for i := 1; i <= size; i++ { + g[i-1].SetUint64(uint64(i)). 
+ Mul(&g[i-1], &mixer) + } + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} From 6ae75a9a428f143b618d7a1660d6995606c7ba8d Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 25 Nov 2024 14:23:33 -0600 Subject: [PATCH 27/74] fix: restore line return to minimize diff --- ecc/bls12-377/fp/element.go | 1 + ecc/bls12-377/fr/element.go | 1 + ecc/bls12-381/fp/element.go | 1 + ecc/bls12-381/fr/element.go | 1 + ecc/bls24-315/fp/element.go | 1 + ecc/bls24-315/fr/element.go | 1 + ecc/bls24-317/fp/element.go | 1 + ecc/bls24-317/fr/element.go | 1 + ecc/bn254/fp/element.go | 1 + ecc/bn254/fr/element.go | 1 + ecc/bw6-633/fp/element.go | 1 + ecc/bw6-633/fr/element.go | 1 + ecc/bw6-761/fp/element.go | 1 + ecc/bw6-761/fr/element.go | 1 + ecc/secp256k1/fp/element.go | 1 + ecc/secp256k1/fr/element.go | 1 + ecc/stark-curve/fp/element.go | 1 + ecc/stark-curve/fr/element.go | 1 + field/generator/internal/templates/element/conv.go | 1 + field/goldilocks/element.go | 1 + 20 files changed, 20 insertions(+) diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 5266ace480..393f45744d 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -1119,6 +1119,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index d71abadc0d..af277e8bb1 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -960,6 +960,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index 5808de88cd..f0bcfe51bc 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -1119,6 +1119,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index f5f88277b8..dc38f08cd3 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -960,6 +960,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls24-315/fp/element.go b/ecc/bls24-315/fp/element.go index 6625335bc9..4ab67695e3 100644 --- a/ecc/bls24-315/fp/element.go +++ b/ecc/bls24-315/fp/element.go @@ -1035,6 +1035,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index 7c297b79dc..abdb822acf 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -960,6 +960,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if 
bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls24-317/fp/element.go b/ecc/bls24-317/fp/element.go index a4a2ee282e..77818de479 100644 --- a/ecc/bls24-317/fp/element.go +++ b/ecc/bls24-317/fp/element.go @@ -1035,6 +1035,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bls24-317/fr/element.go b/ecc/bls24-317/fr/element.go index bf936ea4e9..3aefaebe62 100644 --- a/ecc/bls24-317/fr/element.go +++ b/ecc/bls24-317/fr/element.go @@ -960,6 +960,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index 2b207e73b7..25fcdb67cc 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -960,6 +960,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index eb95ff30e6..3650c954c5 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -960,6 +960,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bw6-633/fp/element.go b/ecc/bw6-633/fp/element.go index 2165d16e17..7656002f47 100644 --- a/ecc/bw6-633/fp/element.go +++ b/ecc/bw6-633/fp/element.go @@ -1515,6 +1515,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bw6-633/fr/element.go b/ecc/bw6-633/fr/element.go index f019a15673..8841cd342c 100644 --- a/ecc/bw6-633/fr/element.go +++ b/ecc/bw6-633/fr/element.go @@ -1035,6 +1035,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bw6-761/fp/element.go b/ecc/bw6-761/fp/element.go index 3c8dcc99aa..8cdd31218e 100644 --- a/ecc/bw6-761/fp/element.go +++ b/ecc/bw6-761/fp/element.go @@ -1749,6 +1749,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index a11193a4de..6784bc911f 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -1119,6 +1119,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/secp256k1/fp/element.go b/ecc/secp256k1/fp/element.go index 
382379a65d..73045a133c 100644 --- a/ecc/secp256k1/fp/element.go +++ b/ecc/secp256k1/fp/element.go @@ -988,6 +988,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/secp256k1/fr/element.go b/ecc/secp256k1/fr/element.go index 5b89009e92..e2f81b66b3 100644 --- a/ecc/secp256k1/fr/element.go +++ b/ecc/secp256k1/fr/element.go @@ -988,6 +988,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/stark-curve/fp/element.go b/ecc/stark-curve/fp/element.go index 990494c9f5..1c53dcb090 100644 --- a/ecc/stark-curve/fp/element.go +++ b/ecc/stark-curve/fp/element.go @@ -960,6 +960,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/ecc/stark-curve/fr/element.go b/ecc/stark-curve/fr/element.go index bd0d4aae0d..216e287ebb 100644 --- a/ecc/stark-curve/fr/element.go +++ b/ecc/stark-curve/fr/element.go @@ -960,6 +960,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) diff --git a/field/generator/internal/templates/element/conv.go b/field/generator/internal/templates/element/conv.go index 17298d4bb0..8b021d4a69 100644 --- a/field/generator/internal/templates/element/conv.go +++ b/field/generator/internal/templates/element/conv.go @@ -211,6 +211,7 @@ func (z *{{.ElementName}}) setBigInt(v *big.Int) *{{.ElementName}} { z[i] = {{$.Word.TypeLower}}(vBits[i]) } {{- else}} + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = {{$.Word.TypeLower}}(vBits[i]) diff --git a/field/goldilocks/element.go b/field/goldilocks/element.go index 889085d94c..e48950ae15 100644 --- a/field/goldilocks/element.go +++ b/field/goldilocks/element.go @@ -749,6 +749,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) From 18b73743d42ab6bc2e19b105a420cabf16607785 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 25 Nov 2024 14:34:28 -0600 Subject: [PATCH 28/74] feat: restore non-big int inverse --- field/babybear/element.go | 67 +++++++++++++++++-- .../internal/templates/element/inverse.go | 22 +++--- field/koalabear/element.go | 67 +++++++++++++++++-- 3 files changed, 136 insertions(+), 20 deletions(-) diff --git a/field/babybear/element.go b/field/babybear/element.go index b4f4fb4f23..a68d9f4db3 100644 --- a/field/babybear/element.go +++ b/field/babybear/element.go @@ -965,11 +965,68 @@ func (z *Element) Sqrt(x *Element) *Element { // Inverse z = x⁻¹ (mod q) // -// note: allocates a big.Int (math/big) +// if x == 0, sets and returns z = x func (z *Element) Inverse(x *Element) *Element { - var _xNonMont big.Int - x.BigInt(&_xNonMont) - _xNonMont.ModInverse(&_xNonMont, Modulus()) - z.SetBigInt(&_xNonMont) + // Algorithm 16 in "Efficient 
Software-Implementation of Finite Fields with Applications to Cryptography" + const q uint32 = q0 + if x.IsZero() { + z.SetZero() + return z + } + + var r, s, u, v uint32 + u = q + s = 1172168163 // s = r² + r = 0 + v = x[0] + + var carry, borrow uint32 + + for (u != 1) && (v != 1) { + for v&1 == 0 { + v >>= 1 + if s&1 == 0 { + s >>= 1 + } else { + s, carry = bits.Add32(s, q, 0) + s >>= 1 + if carry != 0 { + s |= (1 << 31) + } + } + } + for u&1 == 0 { + u >>= 1 + if r&1 == 0 { + r >>= 1 + } else { + r, carry = bits.Add32(r, q, 0) + r >>= 1 + if carry != 0 { + r |= (1 << 31) + } + } + } + if v >= u { + v -= u + s, borrow = bits.Sub32(s, r, 0) + if borrow == 1 { + s += q + } + } else { + u -= v + r, borrow = bits.Sub32(r, s, 0) + if borrow == 1 { + r += q + } + } + } + + if u == 1 { + z[0] = r + } else { + z[0] = s + } + return z } diff --git a/field/generator/internal/templates/element/inverse.go b/field/generator/internal/templates/element/inverse.go index a9cd2ff13d..317f044b0b 100644 --- a/field/generator/internal/templates/element/inverse.go +++ b/field/generator/internal/templates/element/inverse.go @@ -20,25 +20,27 @@ if b != 0 { {{/* We use big.Int for Inverse for these type of moduli */}} {{if not $.UsingP20Inverse}} -{{- if and (eq .NbWords 1) (eq .Word.BitSize 64)}} +{{- if eq .NbWords 1}} // Inverse z = x⁻¹ (mod q) // // if x == 0, sets and returns z = x func (z *{{.ElementName}}) Inverse( x *{{.ElementName}}) *{{.ElementName}} { // Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" - const q uint64 = q0 + const q {{.Word.TypeLower}} = q0 if x.IsZero() { z.SetZero() return z } - var r,s,u,v uint64 + var r,s,u,v {{.Word.TypeLower}} u = q s = {{index .RSquare 0}} // s = r² r = 0 v = x[0] - var carry, borrow uint64 + var carry, borrow {{.Word.TypeLower}} + + {{- $bitSizeMinus1 := sub .Word.BitSize 1}} for (u != 1) && (v != 1){ for v&1 == 0 { @@ -46,10 +48,10 @@ func (z *{{.ElementName}}) Inverse( x *{{.ElementName}}) *{{.ElementName}} { if s&1 == 0 { s >>= 1 } else { - s, carry = bits.Add64(s, q, 0) + s, carry = bits.{{.Word.Add}}(s, q, 0) s >>= 1 if carry != 0 { - s |= (1 << 63) + s |= (1 << {{$bitSizeMinus1}}) } } } @@ -58,22 +60,22 @@ func (z *{{.ElementName}}) Inverse( x *{{.ElementName}}) *{{.ElementName}} { if r&1 == 0 { r >>= 1 } else { - r, carry = bits.Add64(r, q, 0) + r, carry = bits.{{.Word.Add}}(r, q, 0) r >>= 1 if carry != 0 { - r |= (1 << 63) + r |= (1 << {{$bitSizeMinus1}}) } } } if v >= u { v -= u - s, borrow = bits.Sub64(s, r, 0) + s, borrow = bits.{{.Word.Sub}}(s, r, 0) if borrow == 1 { s += q } } else { u -= v - r, borrow = bits.Sub64(r, s, 0) + r, borrow = bits.{{.Word.Sub}}(r, s, 0) if borrow == 1 { r += q } diff --git a/field/koalabear/element.go b/field/koalabear/element.go index acb810492d..f4fe7c4a24 100644 --- a/field/koalabear/element.go +++ b/field/koalabear/element.go @@ -965,11 +965,68 @@ func (z *Element) Sqrt(x *Element) *Element { // Inverse z = x⁻¹ (mod q) // -// note: allocates a big.Int (math/big) +// if x == 0, sets and returns z = x func (z *Element) Inverse(x *Element) *Element { - var _xNonMont big.Int - x.BigInt(&_xNonMont) - _xNonMont.ModInverse(&_xNonMont, Modulus()) - z.SetBigInt(&_xNonMont) + // Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" + const q uint32 = q0 + if x.IsZero() { + z.SetZero() + return z + } + + var r, s, u, v uint32 + u = q + s = 402124772 // s = r² + r = 0 + v = x[0] + + var carry, borrow uint32 + + for (u != 1) 
&& (v != 1) { + for v&1 == 0 { + v >>= 1 + if s&1 == 0 { + s >>= 1 + } else { + s, carry = bits.Add32(s, q, 0) + s >>= 1 + if carry != 0 { + s |= (1 << 31) + } + } + } + for u&1 == 0 { + u >>= 1 + if r&1 == 0 { + r >>= 1 + } else { + r, carry = bits.Add32(r, q, 0) + r >>= 1 + if carry != 0 { + r |= (1 << 31) + } + } + } + if v >= u { + v -= u + s, borrow = bits.Sub32(s, r, 0) + if borrow == 1 { + s += q + } + } else { + u -= v + r, borrow = bits.Sub32(r, s, 0) + if borrow == 1 { + r += q + } + } + } + + if u == 1 { + z[0] = r + } else { + z[0] = s + } + return z } From f86ef7218b6972bfb468c2c947d96fe40c7f5bb5 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 25 Nov 2024 14:39:15 -0600 Subject: [PATCH 29/74] test: fix field config test to take word size into account --- field/generator/config/field_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/field/generator/config/field_test.go b/field/generator/config/field_test.go index a15370f728..27cc112e6f 100644 --- a/field/generator/config/field_test.go +++ b/field/generator/config/field_test.go @@ -43,7 +43,7 @@ func TestIntToMont(t *testing.T) { func(f *FieldConfig) (bool, error) { // test if using the same R i := big.NewInt(1) - i.Lsh(i, 64*uint(f.NbWords)) + i.Lsh(i, uint(f.Word.BitSize)*uint(f.NbWords)) *i = f.ToMont(*i) err := bigIntMatchUint64Slice(i, f.RSquare) From 2576e723f8df9ebe3bdcdd8c7ef6c46c0fe222b7 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 4 Dec 2024 14:24:56 -0600 Subject: [PATCH 30/74] feat: cleanup add template for field element --- field/babybear/element.go | 2 +- .../internal/templates/element/base.go | 39 ++++++++++++------- field/koalabear/element.go | 2 +- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/field/babybear/element.go b/field/babybear/element.go index a68d9f4db3..4841ad96aa 100644 --- a/field/babybear/element.go +++ b/field/babybear/element.go @@ -349,7 +349,7 @@ func (z *Element) fromMont() *Element { // Add z = x + y (mod q) func (z *Element) Add(x, y *Element) *Element { - z[0], _ = bits.Add32(x[0], y[0], 0) + z[0] = x[0] + y[0] if z[0] >= q { z[0] -= q } diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index 7c600941bc..57ffc01b19 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -414,20 +414,30 @@ func (z *{{.ElementName}}) fromMont() *{{.ElementName}} { // Add z = x + y (mod q) func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { - {{ $hasCarry := or (not $.NoCarry) (gt $.NbWords 1)}} - {{- if $hasCarry}} - var carry {{$.Word.TypeLower}} - {{- end}} - {{- range $i := iterate 0 $.NbWords}} - {{- $hasCarry := or (not $.NoCarry) (lt $i $.NbWordsLastIndex)}} - z[{{$i}}], {{- if $hasCarry}}carry{{- else}}_{{- end}} = bits.{{$.Word.Add}}(x[{{$i}}], y[{{$i}}], {{- if eq $i 0}}0{{- else}}carry{{- end}}) - {{- end}} - - {{- if eq $.NbWords 1}} - if {{- if not .NoCarry}} carry != 0 ||{{- end }} z[0] >= q { - z[0] -= q - } + {{- if eq .NbWords 1}} + {{ $hasCarry := (not $.NoCarry)}} + {{- if $hasCarry}} + var carry {{$.Word.TypeLower}} + z[0], carry = bits.{{$.Word.Add}}(x[0], y[0], 0) + if carry != 0 || z[0] >= q { + z[0] -= q + } + return z + {{- else}} + z[0] = x[0] + y[0] + if z[0] >= q { + z[0] -= q + } + return z + {{- end}} {{- else}} + + var carry uint64 + {{- range $i := iterate 0 $.NbWords}} + {{- $hasCarry := or (not $.NoCarry) (lt $i $.NbWordsLastIndex)}} + z[{{$i}}], {{- if 
$hasCarry}}carry{{- else}}_{{- end}} = bits.Add64(x[{{$i}}], y[{{$i}}], {{- if eq $i 0}}0{{- else}}carry{{- end}}) + {{- end}} + {{- if not .NoCarry}} // if we overflowed the last addition, z >= q // if z >= q, z = z - q @@ -441,10 +451,9 @@ func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { return z } {{- end}} - {{ template "reduce" .}} + return z {{- end}} - return z } diff --git a/field/koalabear/element.go b/field/koalabear/element.go index f4fe7c4a24..645d192926 100644 --- a/field/koalabear/element.go +++ b/field/koalabear/element.go @@ -349,7 +349,7 @@ func (z *Element) fromMont() *Element { // Add z = x + y (mod q) func (z *Element) Add(x, y *Element) *Element { - z[0], _ = bits.Add32(x[0], y[0], 0) + z[0] = x[0] + y[0] if z[0] >= q { z[0] -= q } From 93b6669fb3d05ca2c4fe59779d6196c8887d24d2 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 4 Dec 2024 14:46:19 -0600 Subject: [PATCH 31/74] feat: less ops in mul generic 31bits --- field/babybear/element_purego.go | 20 +++++++--------- .../internal/templates/element/mul_cios.go | 23 ++++++++++++------- field/goldilocks/element_purego.go | 6 ++--- field/koalabear/element_purego.go | 20 +++++++--------- 4 files changed, 33 insertions(+), 36 deletions(-) diff --git a/field/babybear/element_purego.go b/field/babybear/element_purego.go index a768309c08..9d0d199466 100644 --- a/field/babybear/element_purego.go +++ b/field/babybear/element_purego.go @@ -64,20 +64,18 @@ func (z *Element) Mul(x, y *Element) *Element { // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. - var r uint32 hi, lo := bits.Mul32(x[0], y[0]) if lo != 0 { hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow } m := lo * qInvNeg hi2, _ := bits.Mul32(m, q) - r, carry := bits.Add32(hi2, hi, 0) - if carry != 0 || r >= q { - // we need to reduce - r -= q + hi2 += hi + if hi2 >= q { + hi2 -= q } - z[0] = r + z[0] = hi2 return z } @@ -100,20 +98,18 @@ func (z *Element) Square(x *Element) *Element { // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. - var r uint32 hi, lo := bits.Mul32(x[0], x[0]) if lo != 0 { hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow } m := lo * qInvNeg hi2, _ := bits.Mul32(m, q) - r, carry := bits.Add32(hi2, hi, 0) - if carry != 0 || r >= q { - // we need to reduce - r -= q + hi2 += hi + if hi2 >= q { + hi2 -= q } - z[0] = r + z[0] = hi2 return z } diff --git a/field/generator/internal/templates/element/mul_cios.go b/field/generator/internal/templates/element/mul_cios.go index 7d19d86493..965024efda 100644 --- a/field/generator/internal/templates/element/mul_cios.go +++ b/field/generator/internal/templates/element/mul_cios.go @@ -110,20 +110,27 @@ const MulCIOS = ` // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. 
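To make the comment above concrete, here is a minimal, self-contained sketch of the single-word REDC it describes, specialized to babybear's constants purely for illustration; mulMont and toMont are names local to this sketch, not the generated API.

package main

import (
	"fmt"
	"math/bits"
)

const (
	q       uint32 = 2013265921 // babybear modulus, for illustration
	qInvNeg uint32 = 2013265919 // -q⁻¹ mod 2³², matching the generated constant
)

// mulMont returns x·y·2⁻³² mod q for Montgomery-form inputs x, y < q.
func mulMont(x, y uint32) uint32 {
	hi, lo := bits.Mul32(x, y)
	if lo != 0 {
		hi++ // lo + (m·q mod 2³²) is 0 when lo == 0 and exactly 2³² otherwise
	}
	m := lo * qInvNeg
	hi2, _ := bits.Mul32(m, q)
	r := hi + hi2 // hi ≤ 2³⁰ and hi2 < q < 2³¹, so this cannot wrap uint32
	if r >= q {
		r -= q
	}
	return r
}

// toMont maps v into Montgomery form: v·2³² mod q.
func toMont(v uint32) uint32 { return uint32((uint64(v) << 32) % uint64(q)) }

func main() {
	fmt.Println(mulMont(toMont(3), toMont(5)) == toMont(15)) // true: 3·5 = 15 (mod q)
}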
- var r {{$.all.Word.TypeLower}} hi, lo := bits.{{$.all.Word.Mul}}({{$.V1}}[0], {{$.V2}}[0]) if lo != 0 { hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow } m := lo * qInvNeg hi2, _ := bits.{{$.all.Word.Mul}}(m, q) - r, carry := bits.{{$.all.Word.Add}}(hi2, hi, 0) - - if carry != 0 || r >= q { - // we need to reduce - r -= q - } - z[0] = r + {{ $hasCarry := (not $.all.NoCarry)}} + {{- if $hasCarry}} + r, carry := bits.{{$.all.Word.Add}}(hi2, hi, 0) + if carry != 0 || r >= q { + // we need to reduce + r -= q + } + z[0] = r + {{- else}} + hi2 += hi + if hi2 >= q { + hi2 -= q + } + z[0] = hi2 + {{- end}} {{ end }} ` diff --git a/field/goldilocks/element_purego.go b/field/goldilocks/element_purego.go index f1090ab75f..c7b2647b18 100644 --- a/field/goldilocks/element_purego.go +++ b/field/goldilocks/element_purego.go @@ -62,15 +62,14 @@ func (z *Element) Mul(x, y *Element) *Element { // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. - var r uint64 hi, lo := bits.Mul64(x[0], y[0]) if lo != 0 { hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow } m := lo * qInvNeg hi2, _ := bits.Mul64(m, q) - r, carry := bits.Add64(hi2, hi, 0) + r, carry := bits.Add64(hi2, hi, 0) if carry != 0 || r >= q { // we need to reduce r -= q @@ -96,15 +95,14 @@ func (z *Element) Square(x *Element) *Element { // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. - var r uint64 hi, lo := bits.Mul64(x[0], x[0]) if lo != 0 { hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow } m := lo * qInvNeg hi2, _ := bits.Mul64(m, q) - r, carry := bits.Add64(hi2, hi, 0) + r, carry := bits.Add64(hi2, hi, 0) if carry != 0 || r >= q { // we need to reduce r -= q diff --git a/field/koalabear/element_purego.go b/field/koalabear/element_purego.go index ea28aa3cd5..628e2c4afb 100644 --- a/field/koalabear/element_purego.go +++ b/field/koalabear/element_purego.go @@ -64,20 +64,18 @@ func (z *Element) Mul(x, y *Element) *Element { // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. - var r uint32 hi, lo := bits.Mul32(x[0], y[0]) if lo != 0 { hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow } m := lo * qInvNeg hi2, _ := bits.Mul32(m, q) - r, carry := bits.Add32(hi2, hi, 0) - if carry != 0 || r >= q { - // we need to reduce - r -= q + hi2 += hi + if hi2 >= q { + hi2 -= q } - z[0] = r + z[0] = hi2 return z } @@ -100,20 +98,18 @@ func (z *Element) Square(x *Element) *Element { // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. 
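A note on why the carry-free branch introduced above is sound for the 31-bit moduli in this series: with x, y < q < 2³¹, the high word satisfies hi = ⌊x·y / 2³²⌋ ≤ ⌊(q-1)² / 2³²⌋ < q even after the (lo != 0) increment, and hi2 = ⌊m·q / 2³²⌋ < q, so hi + hi2 < 2q < 2³². The plain addition can therefore neither wrap uint32 nor need more than one conditional subtraction. For a one-word 64-bit field such as goldilocks, q is close to 2⁶⁴ and hi + hi2 can exceed 2⁶⁴, which is why the template keeps the bits.Add64 carry path in that case.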
- var r uint32 hi, lo := bits.Mul32(x[0], x[0]) if lo != 0 { hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow } m := lo * qInvNeg hi2, _ := bits.Mul32(m, q) - r, carry := bits.Add32(hi2, hi, 0) - if carry != 0 || r >= q { - // we need to reduce - r -= q + hi2 += hi + if hi2 >= q { + hi2 -= q } - z[0] = r + z[0] = hi2 return z } From f276812c4f50acab3c4f598e4135d13f251e3164 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 6 Dec 2024 20:41:34 +0000 Subject: [PATCH 32/74] style: cleaning PR --- ecc/bls12-377/fp/element.go | 14 ++-- ecc/bls12-377/fr/element.go | 10 +-- ecc/bls12-381/fp/element.go | 14 ++-- ecc/bls12-381/fr/element.go | 10 +-- ecc/bls24-315/fp/element.go | 12 ++-- ecc/bls24-315/fr/element.go | 10 +-- ecc/bls24-317/fp/element.go | 12 ++-- ecc/bls24-317/fr/element.go | 10 +-- ecc/bn254/fp/element.go | 10 +-- ecc/bn254/fr/element.go | 10 +-- ecc/bw6-633/fp/element.go | 22 +++--- ecc/bw6-633/fr/element.go | 12 ++-- ecc/bw6-761/fp/element.go | 26 +++---- ecc/bw6-761/fr/element.go | 14 ++-- ecc/secp256k1/fp/element.go | 10 +-- ecc/secp256k1/fr/element.go | 10 +-- ecc/stark-curve/fp/element.go | 10 +-- ecc/stark-curve/fr/element.go | 10 +-- field/babybear/arith.go | 48 ++++++------- field/babybear/element.go | 72 +++---------------- field/babybear/element_purego.go | 68 ++++-------------- field/babybear/element_test.go | 26 ------- field/generator/config/field_config.go | 10 +-- field/generator/generator.go | 2 +- .../internal/templates/element/arith.go | 48 ++++++------- .../internal/templates/element/base.go | 63 ++++++++-------- .../internal/templates/element/conv.go | 24 ++++--- .../internal/templates/element/mul_cios.go | 6 +- .../internal/templates/element/ops_purego.go | 29 +++++++- .../internal/templates/element/tests.go | 4 ++ field/goldilocks/element.go | 49 +------------ field/goldilocks/element_test.go | 26 ------- field/koalabear/arith.go | 48 ++++++------- field/koalabear/element.go | 72 +++---------------- field/koalabear/element_purego.go | 68 ++++-------------- field/koalabear/element_test.go | 26 ------- 36 files changed, 321 insertions(+), 594 deletions(-) diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 393f45744d..1c135e16f9 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -54,12 +54,12 @@ const ( // Field modulus q const ( - q0 uint64 = 9586122913090633729 - q1 uint64 = 1660523435060625408 - q2 uint64 = 2230234197602682880 - q3 uint64 = 1883307231910630287 - q4 uint64 = 14284016967150029115 - q5 uint64 = 121098312706494698 + q0 = 9586122913090633729 + q1 = 1660523435060625408 + q2 = 2230234197602682880 + q3 = 1883307231910630287 + q4 = 14284016967150029115 + q5 = 121098312706494698 ) var qElement = Element{ @@ -83,7 +83,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 9586122913090633727 +const qInvNeg = 9586122913090633727 func init() { _modulus.SetString("1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001", 16) diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index af277e8bb1..93c0d1cc7d 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -54,10 +54,10 @@ const ( // Field modulus q const ( - q0 uint64 = 725501752471715841 - q1 uint64 = 6461107452199829505 - q2 uint64 = 6968279316240510977 - q3 uint64 = 1345280370688173398 + q0 = 725501752471715841 + q1 = 6461107452199829505 + q2 = 
6968279316240510977 + q3 = 1345280370688173398 ) var qElement = Element{ @@ -79,7 +79,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 725501752471715839 +const qInvNeg = 725501752471715839 // mu = 2^288 / q needed for partial Barrett reduction const mu uint64 = 58893420465 diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index f0bcfe51bc..d5c6cefb2a 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -54,12 +54,12 @@ const ( // Field modulus q const ( - q0 uint64 = 13402431016077863595 - q1 uint64 = 2210141511517208575 - q2 uint64 = 7435674573564081700 - q3 uint64 = 7239337960414712511 - q4 uint64 = 5412103778470702295 - q5 uint64 = 1873798617647539866 + q0 = 13402431016077863595 + q1 = 2210141511517208575 + q2 = 7435674573564081700 + q3 = 7239337960414712511 + q4 = 5412103778470702295 + q5 = 1873798617647539866 ) var qElement = Element{ @@ -83,7 +83,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 9940570264628428797 +const qInvNeg = 9940570264628428797 func init() { _modulus.SetString("1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab", 16) diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index dc38f08cd3..5d218be808 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -54,10 +54,10 @@ const ( // Field modulus q const ( - q0 uint64 = 18446744069414584321 - q1 uint64 = 6034159408538082302 - q2 uint64 = 3691218898639771653 - q3 uint64 = 8353516859464449352 + q0 = 18446744069414584321 + q1 = 6034159408538082302 + q2 = 3691218898639771653 + q3 = 8353516859464449352 ) var qElement = Element{ @@ -79,7 +79,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 18446744069414584319 +const qInvNeg = 18446744069414584319 // mu = 2^288 / q needed for partial Barrett reduction const mu uint64 = 9484408045 diff --git a/ecc/bls24-315/fp/element.go b/ecc/bls24-315/fp/element.go index 4ab67695e3..361e0482c6 100644 --- a/ecc/bls24-315/fp/element.go +++ b/ecc/bls24-315/fp/element.go @@ -54,11 +54,11 @@ const ( // Field modulus q const ( - q0 uint64 = 8063698428123676673 - q1 uint64 = 4764498181658371330 - q2 uint64 = 16051339359738796768 - q3 uint64 = 15273757526516850351 - q4 uint64 = 342900304943437392 + q0 = 8063698428123676673 + q1 = 4764498181658371330 + q2 = 16051339359738796768 + q3 = 15273757526516850351 + q4 = 342900304943437392 ) var qElement = Element{ @@ -81,7 +81,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 8083954730842193919 +const qInvNeg = 8083954730842193919 func init() { _modulus.SetString("4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001", 16) diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index abdb822acf..2fe6f12cdc 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -54,10 +54,10 @@ const ( // Field modulus q const ( - q0 uint64 = 1860204336533995521 - q1 uint64 = 14466829657984787300 - q2 uint64 = 2737202078770428568 - q3 uint64 = 1832378743606059307 + q0 = 1860204336533995521 + q1 = 14466829657984787300 + q2 = 2737202078770428568 + q3 = 1832378743606059307 ) var qElement = Element{ @@ -79,7 +79,7 @@ func Modulus() *big.Int { // q + 
r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 2184305180030271487 +const qInvNeg = 2184305180030271487 // mu = 2^288 / q needed for partial Barrett reduction const mu uint64 = 43237874697 diff --git a/ecc/bls24-317/fp/element.go b/ecc/bls24-317/fp/element.go index 77818de479..80bab83a71 100644 --- a/ecc/bls24-317/fp/element.go +++ b/ecc/bls24-317/fp/element.go @@ -54,11 +54,11 @@ const ( // Field modulus q const ( - q0 uint64 = 10182971180934965931 - q1 uint64 = 15488787195747417982 - q2 uint64 = 1628721857945875526 - q3 uint64 = 17478405972920225849 - q4 uint64 = 1177913551803681068 + q0 = 10182971180934965931 + q1 = 15488787195747417982 + q2 = 1628721857945875526 + q3 = 17478405972920225849 + q4 = 1177913551803681068 ) var qElement = Element{ @@ -81,7 +81,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 6176088765535387645 +const qInvNeg = 6176088765535387645 func init() { _modulus.SetString("1058ca226f60892cf28fc5a0b7f9d039169a61e684c73446d6f339e43424bf7e8d512e565dab2aab", 16) diff --git a/ecc/bls24-317/fr/element.go b/ecc/bls24-317/fr/element.go index 3aefaebe62..df5f9dc4cf 100644 --- a/ecc/bls24-317/fr/element.go +++ b/ecc/bls24-317/fr/element.go @@ -54,10 +54,10 @@ const ( // Field modulus q const ( - q0 uint64 = 17293822569102704641 - q1 uint64 = 2076695515679886970 - q2 uint64 = 15037686223802191177 - q3 uint64 = 4917809291258081218 + q0 = 17293822569102704641 + q1 = 2076695515679886970 + q2 = 15037686223802191177 + q3 = 4917809291258081218 ) var qElement = Element{ @@ -79,7 +79,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 17293822569102704639 +const qInvNeg = 17293822569102704639 // mu = 2^288 / q needed for partial Barrett reduction const mu uint64 = 16110458503 diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index 25fcdb67cc..395030094f 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -54,10 +54,10 @@ const ( // Field modulus q const ( - q0 uint64 = 4332616871279656263 - q1 uint64 = 10917124144477883021 - q2 uint64 = 13281191951274694749 - q3 uint64 = 3486998266802970665 + q0 = 4332616871279656263 + q1 = 10917124144477883021 + q2 = 13281191951274694749 + q3 = 3486998266802970665 ) var qElement = Element{ @@ -79,7 +79,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 9786893198990664585 +const qInvNeg = 9786893198990664585 // mu = 2^288 / q needed for partial Barrett reduction const mu uint64 = 22721021478 diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index 3650c954c5..f8b6c10840 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -54,10 +54,10 @@ const ( // Field modulus q const ( - q0 uint64 = 4891460686036598785 - q1 uint64 = 2896914383306846353 - q2 uint64 = 13281191951274694749 - q3 uint64 = 3486998266802970665 + q0 = 4891460686036598785 + q1 = 2896914383306846353 + q2 = 13281191951274694749 + q3 = 3486998266802970665 ) var qElement = Element{ @@ -79,7 +79,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 14042775128853446655 +const qInvNeg = 14042775128853446655 // mu = 2^288 / q needed for partial Barrett reduction const mu uint64 = 22721021478 diff --git a/ecc/bw6-633/fp/element.go b/ecc/bw6-633/fp/element.go index 
7656002f47..6f2565eaf7 100644 --- a/ecc/bw6-633/fp/element.go +++ b/ecc/bw6-633/fp/element.go @@ -54,16 +54,16 @@ const ( // Field modulus q const ( - q0 uint64 = 15512955586897510413 - q1 uint64 = 4410884215886313276 - q2 uint64 = 15543556715411259941 - q3 uint64 = 9083347379620258823 - q4 uint64 = 13320134076191308873 - q5 uint64 = 9318693926755804304 - q6 uint64 = 5645674015335635503 - q7 uint64 = 12176845843281334983 - q8 uint64 = 18165857675053050549 - q9 uint64 = 82862755739295587 + q0 = 15512955586897510413 + q1 = 4410884215886313276 + q2 = 15543556715411259941 + q3 = 9083347379620258823 + q4 = 13320134076191308873 + q5 = 9318693926755804304 + q6 = 5645674015335635503 + q7 = 12176845843281334983 + q8 = 18165857675053050549 + q9 = 82862755739295587 ) var qElement = Element{ @@ -91,7 +91,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 13046692460116554043 +const qInvNeg = 13046692460116554043 func init() { _modulus.SetString("126633cc0f35f63fc1a174f01d72ab5a8fcd8c75d79d2c74e59769ad9bbda2f8152a6c0fadea490b8da9f5e83f57c497e0e8850edbda407d7b5ce7ab839c2253d369bd31147f73cd74916ea4570000d", 16) diff --git a/ecc/bw6-633/fr/element.go b/ecc/bw6-633/fr/element.go index 8841cd342c..b9cff65b7a 100644 --- a/ecc/bw6-633/fr/element.go +++ b/ecc/bw6-633/fr/element.go @@ -54,11 +54,11 @@ const ( // Field modulus q const ( - q0 uint64 = 8063698428123676673 - q1 uint64 = 4764498181658371330 - q2 uint64 = 16051339359738796768 - q3 uint64 = 15273757526516850351 - q4 uint64 = 342900304943437392 + q0 = 8063698428123676673 + q1 = 4764498181658371330 + q2 = 16051339359738796768 + q3 = 15273757526516850351 + q4 = 342900304943437392 ) var qElement = Element{ @@ -81,7 +81,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 8083954730842193919 +const qInvNeg = 8083954730842193919 func init() { _modulus.SetString("4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001", 16) diff --git a/ecc/bw6-761/fp/element.go b/ecc/bw6-761/fp/element.go index 8cdd31218e..a266ee9014 100644 --- a/ecc/bw6-761/fp/element.go +++ b/ecc/bw6-761/fp/element.go @@ -54,18 +54,18 @@ const ( // Field modulus q const ( - q0 uint64 = 17626244516597989515 - q1 uint64 = 16614129118623039618 - q2 uint64 = 1588918198704579639 - q3 uint64 = 10998096788944562424 - q4 uint64 = 8204665564953313070 - q5 uint64 = 9694500593442880912 - q6 uint64 = 274362232328168196 - q7 uint64 = 8105254717682411801 - q8 uint64 = 5945444129596489281 - q9 uint64 = 13341377791855249032 - q10 uint64 = 15098257552581525310 - q11 uint64 = 81882988782276106 + q0 = 17626244516597989515 + q1 = 16614129118623039618 + q2 = 1588918198704579639 + q3 = 10998096788944562424 + q4 = 8204665564953313070 + q5 = 9694500593442880912 + q6 = 274362232328168196 + q7 = 8105254717682411801 + q8 = 5945444129596489281 + q9 = 13341377791855249032 + q10 = 15098257552581525310 + q11 = 81882988782276106 ) var qElement = Element{ @@ -95,7 +95,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 744663313386281181 +const qInvNeg = 744663313386281181 func init() { _modulus.SetString("122e824fb83ce0ad187c94004faff3eb926186a81d14688528275ef8087be41707ba638e584e91903cebaff25b423048689c8ed12f9fd9071dcd3dc73ebff2e98a116c25667a8f8160cf8aeeaf0a437e6913e6870000082f49d00000000008b", 16) diff --git a/ecc/bw6-761/fr/element.go 
b/ecc/bw6-761/fr/element.go index 6784bc911f..dad9bbec1c 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -54,12 +54,12 @@ const ( // Field modulus q const ( - q0 uint64 = 9586122913090633729 - q1 uint64 = 1660523435060625408 - q2 uint64 = 2230234197602682880 - q3 uint64 = 1883307231910630287 - q4 uint64 = 14284016967150029115 - q5 uint64 = 121098312706494698 + q0 = 9586122913090633729 + q1 = 1660523435060625408 + q2 = 2230234197602682880 + q3 = 1883307231910630287 + q4 = 14284016967150029115 + q5 = 121098312706494698 ) var qElement = Element{ @@ -83,7 +83,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 9586122913090633727 +const qInvNeg = 9586122913090633727 func init() { _modulus.SetString("1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001", 16) diff --git a/ecc/secp256k1/fp/element.go b/ecc/secp256k1/fp/element.go index 73045a133c..cb1b02995e 100644 --- a/ecc/secp256k1/fp/element.go +++ b/ecc/secp256k1/fp/element.go @@ -54,10 +54,10 @@ const ( // Field modulus q const ( - q0 uint64 = 18446744069414583343 - q1 uint64 = 18446744073709551615 - q2 uint64 = 18446744073709551615 - q3 uint64 = 18446744073709551615 + q0 = 18446744069414583343 + q1 = 18446744073709551615 + q2 = 18446744073709551615 + q3 = 18446744073709551615 ) var qElement = Element{ @@ -79,7 +79,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 15580212934572586289 +const qInvNeg = 15580212934572586289 // mu = 2^288 / q needed for partial Barrett reduction const mu uint64 = 4294967296 diff --git a/ecc/secp256k1/fr/element.go b/ecc/secp256k1/fr/element.go index e2f81b66b3..d0684c6cc1 100644 --- a/ecc/secp256k1/fr/element.go +++ b/ecc/secp256k1/fr/element.go @@ -54,10 +54,10 @@ const ( // Field modulus q const ( - q0 uint64 = 13822214165235122497 - q1 uint64 = 13451932020343611451 - q2 uint64 = 18446744073709551614 - q3 uint64 = 18446744073709551615 + q0 = 13822214165235122497 + q1 = 13451932020343611451 + q2 = 18446744073709551614 + q3 = 18446744073709551615 ) var qElement = Element{ @@ -79,7 +79,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 5408259542528602431 +const qInvNeg = 5408259542528602431 // mu = 2^288 / q needed for partial Barrett reduction const mu uint64 = 4294967296 diff --git a/ecc/stark-curve/fp/element.go b/ecc/stark-curve/fp/element.go index 1c53dcb090..4b14ff1149 100644 --- a/ecc/stark-curve/fp/element.go +++ b/ecc/stark-curve/fp/element.go @@ -54,10 +54,10 @@ const ( // Field modulus q const ( - q0 uint64 = 1 - q1 uint64 = 0 - q2 uint64 = 0 - q3 uint64 = 576460752303423505 + q0 = 1 + q1 = 0 + q2 = 0 + q3 = 576460752303423505 ) var qElement = Element{ @@ -79,7 +79,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 18446744073709551615 +const qInvNeg = 18446744073709551615 // mu = 2^288 / q needed for partial Barrett reduction const mu uint64 = 137438953471 diff --git a/ecc/stark-curve/fr/element.go b/ecc/stark-curve/fr/element.go index 216e287ebb..2acb6e46b7 100644 --- a/ecc/stark-curve/fr/element.go +++ b/ecc/stark-curve/fr/element.go @@ -54,10 +54,10 @@ const ( // Field modulus q const ( - q0 uint64 = 2190616671734353199 - q1 uint64 = 13222870243701404210 - q2 uint64 = 18446744073709551615 - q3 uint64 = 
576460752303423504 + q0 = 2190616671734353199 + q1 = 13222870243701404210 + q2 = 18446744073709551615 + q3 = 576460752303423504 ) var qElement = Element{ @@ -79,7 +79,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 13504954208620504625 +const qInvNeg = 13504954208620504625 // mu = 2^288 / q needed for partial Barrett reduction const mu uint64 = 137438953471 diff --git a/field/babybear/arith.go b/field/babybear/arith.go index 03c952e9fb..3dfd7e5ffe 100644 --- a/field/babybear/arith.go +++ b/field/babybear/arith.go @@ -21,40 +21,40 @@ import ( ) // madd0 hi = a*b + c (discards lo bits) -func madd0(a, b, c uint32) (hi uint32) { - var carry, lo uint32 - hi, lo = bits.Mul32(a, b) - _, carry = bits.Add32(lo, c, 0) - hi, _ = bits.Add32(hi, 0, carry) +func madd0(a, b, c uint64) (hi uint64) { + var carry, lo uint64 + hi, lo = bits.Mul64(a, b) + _, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) return } // madd1 hi, lo = a*b + c -func madd1(a, b, c uint32) (hi uint32, lo uint32) { - var carry uint32 - hi, lo = bits.Mul32(a, b) - lo, carry = bits.Add32(lo, c, 0) - hi, _ = bits.Add32(hi, 0, carry) +func madd1(a, b, c uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) return } // madd2 hi, lo = a*b + c + d -func madd2(a, b, c, d uint32) (hi uint32, lo uint32) { - var carry uint32 - hi, lo = bits.Mul32(a, b) - c, carry = bits.Add32(c, d, 0) - hi, _ = bits.Add32(hi, 0, carry) - lo, carry = bits.Add32(lo, c, 0) - hi, _ = bits.Add32(hi, 0, carry) +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) return } -func madd3(a, b, c, d, e uint32) (hi uint32, lo uint32) { - var carry uint32 - hi, lo = bits.Mul32(a, b) - c, carry = bits.Add32(c, d, 0) - hi, _ = bits.Add32(hi, 0, carry) - lo, carry = bits.Add32(lo, c, 0) - hi, _ = bits.Add32(hi, e, carry) +func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, e, carry) return } diff --git a/field/babybear/element.go b/field/babybear/element.go index 4841ad96aa..63986904c7 100644 --- a/field/babybear/element.go +++ b/field/babybear/element.go @@ -54,8 +54,8 @@ const ( // Field modulus q const ( - q0 uint32 = 2013265921 - q uint32 = q0 + q0 = 2013265921 + q = q0 ) var qElement = Element{ @@ -74,7 +74,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint32 = 2013265919 +const qInvNeg = 2013265919 func init() { _modulus.SetString("78000001", 16) @@ -88,7 +88,7 @@ func init() { // v.SetUint64(...) 
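The constructor change just below swaps Mul(&z, &rSquare) for the new toMont; the two are interchangeable because v·r² reduced by one Montgomery step equals v·2³² mod q. A hedged sketch checking that equivalence with babybear's constants (montReduce here is a local copy for illustration, not the generated helper):

package main

import "fmt"

const (
	q       uint64 = 2013265921 // babybear modulus, for illustration
	qInvNeg uint64 = 2013265919 // -q⁻¹ mod 2³²
)

// montReduce mirrors the helper this series introduces: v·2⁻³² mod q.
func montReduce(v uint64) uint32 {
	m := (v * qInvNeg) % (1 << 32)
	t := uint32((v + m*q) >> 32)
	if t >= uint32(q) {
		t -= uint32(q)
	}
	return t
}

func main() {
	r := (1 << 32) % q   // 2³² mod q
	rSquare := r * r % q // r² mod q, as in the generated rSquare constant
	for _, v := range []uint64{0, 1, 42, q - 1} {
		viaShift := uint32((v << 32) % q) // the new toMont route
		viaMul := montReduce(v * rSquare) // the old Mul(z, &rSquare) route
		fmt.Println(viaShift == viaMul)   // true for every v
	}
}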
func NewElement(v uint64) Element { z := Element{uint32(v % uint64(q0))} - z.Mul(&z, &rSquare) + z.toMont() return z } @@ -96,7 +96,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{uint32(v % uint64(q0))} - return z.Mul(z, &rSquare) // z.toMont() + return z.toMont() } // SetInt64 sets z to v and returns z @@ -393,64 +393,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } -// _mulGeneric is unoptimized textbook CIOS -// it is a fallback solution on x86 when ADX instruction set is not available -// and is used for testing purposes. -func _mulGeneric(z, x, y *Element) { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t [2]uint32 - var D uint32 - var m, C uint32 - // ----------------------------------- - // First loop - - C, t[0] = bits.Mul32(y[0], x[0]) - - t[1], D = bits.Add32(t[1], C, 0) - - // m = t[0]n'[0] mod W - m = t[0] * qInvNeg - - // ----------------------------------- - // Second loop - C = madd0(m, q0, t[0]) - - t[0], C = bits.Add32(t[1], C, 0) - t[1], _ = bits.Add32(0, D, C) - - if t[1] != 0 { - // we need to reduce, we have a result on 2 words - z[0], _ = bits.Sub32(t[0], q0, 0) - return - } - - // copy t into z - z[0] = t[0] - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - z[0] -= q - } -} - func _fromMontGeneric(z *Element) { - // the following lines implement z = z * 1 - // with a modified CIOS montgomery multiplication - // see Mul for algorithm documentation - { - // m = z[0]n'[0] mod W - m := z[0] * qInvNeg - C := madd0(m, q0, z[0]) - z[0] = C - } - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - z[0] -= q - } + z[0] = montReduce(uint64(z[0])) } func _reduceGeneric(z *Element) { @@ -576,7 +520,9 @@ var rSquare = Element{ // toMont converts z to Montgomery form // sets and returns z = z * r² func (z *Element) toMont() *Element { - return z.Mul(z, &rSquare) + const rBits = 32 + z[0] = uint32((uint64(z[0]) << rBits) % q) + return z } // String returns the decimal representation of z as generated by diff --git a/field/babybear/element_purego.go b/field/babybear/element_purego.go index 9d0d199466..38eb755db7 100644 --- a/field/babybear/element_purego.go +++ b/field/babybear/element_purego.go @@ -16,8 +16,6 @@ package babybear -import "math/bits" - // MulBy3 x *= 3 (mod q) func MulBy3(x *Element) { var y Element @@ -46,37 +44,23 @@ func fromMont(z *Element) { func reduce(z *Element) { _reduceGeneric(z) } +func montReduce(v uint64) uint32 { + const rBits = 32 + const r = 1 << rBits + m := (v * qInvNeg) % r + t := uint32((v + m*q) >> rBits) + if t >= q { + t -= q + } + return t +} // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - - // In fact, since the modulus R fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction): - // hi, lo := x * y - // m := (lo * qInvNeg) mod R - // (*) r := (hi * R + lo + m * q) / R - // reduce r if necessary - - // On the emphasized line, we get r = hi + (lo + m * q) / R - // If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2 - // This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise. 
- // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) - // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. - - hi, lo := bits.Mul32(x[0], y[0]) - if lo != 0 { - hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow - } - m := lo * qInvNeg - hi2, _ := bits.Mul32(m, q) - - hi2 += hi - if hi2 >= q { - hi2 -= q - } - z[0] = hi2 - + v := uint64(x[0]) * uint64(y[0]) + z[0] = montReduce(v) return z } @@ -85,32 +69,8 @@ func (z *Element) Mul(x, y *Element) *Element { // x must be less than q func (z *Element) Square(x *Element) *Element { // see Mul for algorithm documentation - - // In fact, since the modulus R fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction): - // hi, lo := x * y - // m := (lo * qInvNeg) mod R - // (*) r := (hi * R + lo + m * q) / R - // reduce r if necessary - - // On the emphasized line, we get r = hi + (lo + m * q) / R - // If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2 - // This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise. - // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) - // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. - - hi, lo := bits.Mul32(x[0], x[0]) - if lo != 0 { - hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow - } - m := lo * qInvNeg - hi2, _ := bits.Mul32(m, q) - - hi2 += hi - if hi2 >= q { - hi2 -= q - } - z[0] = hi2 - + v := uint64(x[0]) * uint64(x[0]) + z[0] = montReduce(v) return z } diff --git a/field/babybear/element_test.go b/field/babybear/element_test.go index 14034e42a5..161c977b72 100644 --- a/field/babybear/element_test.go +++ b/field/babybear/element_test.go @@ -932,14 +932,6 @@ func TestElementMul(t *testing.T) { c.Mul(&a.element, &r) d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) - // checking generic impl against asm path - var cGeneric Element - _mulGeneric(&cGeneric, &a.element, &r) - if !cGeneric.Equal(&c) { - // need to give context to failing error. 
- return false - } - if c.BigInt(&e).Cmp(&d) != 0 { return false } @@ -962,17 +954,6 @@ func TestElementMul(t *testing.T) { genB, )) - properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - c.Mul(&a.element, &b.element) - _mulGeneric(&d, &a.element, &b.element) - return c.Equal(&d) - }, - genA, - genB, - )) - specialValueTest := func() { // test special values against special values testValues := make([]Element, len(staticTestValues)) @@ -991,13 +972,6 @@ func TestElementMul(t *testing.T) { c.Mul(&a, &b) d.Mul(&aBig, &bBig).Mod(&d, Modulus()) - // checking asm against generic impl - var cGeneric Element - _mulGeneric(&cGeneric, &a, &b) - if !cGeneric.Equal(&c) { - t.Fatal("Mul failed special test values: asm and generic impl don't match") - } - if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 4e6e9e862c..0813a3d46e 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -74,6 +74,7 @@ type FieldConfig struct { UseAddChain bool Word Word // 32 iff Q < 2^32, else 64 + F31 bool // 31 bits field // asm code generation GenerateOpsAMD64 bool @@ -89,7 +90,6 @@ type Word struct { TypeUpper string // Uint32 or Uint64 Add string // Add64 or Add32 Sub string // Sub64 or Sub32 - Mul string // Mul64 or Mul32 Len string // Len64 or Len32 } @@ -114,6 +114,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) } // pre compute field constants F.NbBits = bModulus.BitLen() + F.F31 = F.NbBits == 31 F.NbWords = len(bModulus.Bits()) F.NbWordsLastIndex = F.NbWords - 1 @@ -132,7 +133,6 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) F.Word.TypeUpper = "Uint64" F.Word.Add = "Add64" F.Word.Sub = "Sub64" - F.Word.Mul = "Mul64" F.Word.Len = "Len64" if F.NbBits < 32 { F.Word.BitSize = 32 @@ -141,17 +141,13 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) F.Word.TypeUpper = "Uint32" F.Word.Add = "Add32" F.Word.Sub = "Sub32" - F.Word.Mul = "Mul32" F.Word.Len = "Len32" } F.NbBytes = F.NbWords * F.Word.ByteSize // setting qInverse - radix := uint(64) - if F.Word.BitSize == 32 { - radix = 32 - } + radix := uint(F.Word.BitSize) _r := big.NewInt(1) _r.Lsh(_r, uint(F.NbWords)*radix) diff --git a/field/generator/generator.go b/field/generator/generator.go index 22e0292778..cca9c36721 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -144,7 +144,7 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude g.Go(generate("element.go", sourceFiles)) g.Go(generate("doc.go", []string{element.Doc})) g.Go(generate("vector.go", []string{element.Vector})) - g.Go(generate("arith.go", []string{element.Arith})) + g.Go(generate("arith.go", []string{element.Arith}, Only(F.F31))) g.Go(generate("element_test.go", testFiles)) g.Go(generate("vector_test.go", []string{element.TestVector})) diff --git a/field/generator/internal/templates/element/arith.go b/field/generator/internal/templates/element/arith.go index 427529f2a8..06a7805588 100644 --- a/field/generator/internal/templates/element/arith.go +++ b/field/generator/internal/templates/element/arith.go @@ -6,42 +6,42 @@ import ( ) // madd0 hi = a*b + c (discards lo bits) -func madd0(a, b, c {{$.Word.TypeLower}}) (hi {{$.Word.TypeLower}}) { - var carry, lo 
{{$.Word.TypeLower}} - hi, lo = bits.{{$.Word.Mul}}(a, b) - _, carry = bits.{{$.Word.Add}}(lo, c, 0) - hi, _ = bits.{{$.Word.Add}}(hi, 0, carry) +func madd0(a, b, c uint64) (hi uint64) { + var carry, lo uint64 + hi, lo = bits.Mul64(a, b) + _, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) return } // madd1 hi, lo = a*b + c -func madd1(a, b, c {{$.Word.TypeLower}}) (hi {{$.Word.TypeLower}}, lo {{$.Word.TypeLower}}) { - var carry {{$.Word.TypeLower}} - hi, lo = bits.{{$.Word.Mul}}(a, b) - lo, carry = bits.{{$.Word.Add}}(lo, c, 0) - hi, _ = bits.{{$.Word.Add}}(hi, 0, carry) +func madd1(a, b, c uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) return } // madd2 hi, lo = a*b + c + d -func madd2(a, b, c, d {{$.Word.TypeLower}}) (hi {{$.Word.TypeLower}}, lo {{$.Word.TypeLower}}) { - var carry {{$.Word.TypeLower}} - hi, lo = bits.{{$.Word.Mul}}(a, b) - c, carry = bits.{{$.Word.Add}}(c, d, 0) - hi, _ = bits.{{$.Word.Add}}(hi, 0, carry) - lo, carry = bits.{{$.Word.Add}}(lo, c, 0) - hi, _ = bits.{{$.Word.Add}}(hi, 0, carry) +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) return } -func madd3(a, b, c, d, e {{$.Word.TypeLower}}) (hi {{$.Word.TypeLower}}, lo {{$.Word.TypeLower}}) { - var carry {{$.Word.TypeLower}} - hi, lo = bits.{{$.Word.Mul}}(a, b) - c, carry = bits.{{$.Word.Add}}(c, d, 0) - hi, _ = bits.{{$.Word.Add}}(hi, 0, carry) - lo, carry = bits.{{$.Word.Add}}(lo, c, 0) - hi, _ = bits.{{$.Word.Add}}(hi, e, carry) +func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, e, carry) return } diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index 57ffc01b19..ab4c20d24a 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -42,9 +42,9 @@ const ( // Field modulus q const ( {{- range $i := $.NbWordsIndexesFull}} - q{{$i}} {{$.Word.TypeLower}} = {{index $.Q $i}} + q{{$i}} = {{index $.Q $i}} {{- if eq $.NbWords 1}} - q {{$.Word.TypeLower}} = q0 + q = q0 {{- end}} {{- end}} ) @@ -66,7 +66,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg {{$.Word.TypeLower}} = {{index .QInverse 0}} +const qInvNeg = {{index .QInverse 0}} {{- if eq .NbWords 4}} // mu = 2^288 / q needed for partial Barrett reduction @@ -85,8 +85,8 @@ func init() { // v.SetUint64(...) 
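One subtlety behind dropping the explicit word type on q0 and qInvNeg above: as untyped constants they take whatever type each expression demands, which is what lets the same q appear in uint32 arithmetic (z[0] -= q) and in uint64 arithmetic (the shift in toMont, the product in montReduce) without conversions. A tiny illustration, independent of the generated code:

package main

import "fmt"

const q = 2013265921 // untyped: adapts to the word size at each use site

func main() {
	var x32 uint32 = 123
	x64 := uint64(x32) << 32

	fmt.Println(x32 % q) // q is implicitly uint32 here
	fmt.Println(x64 % q) // and implicitly uint64 here; a typed uint32 q would need a cast
}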
func New{{.ElementName}}(v uint64) {{.ElementName}} { {{- if eq .Word.BitSize 32}} - z := {{.ElementName}}{ {{$.Word.TypeLower}}(v % uint64(q0)) } - z.Mul(&z, &rSquare) + z := {{.ElementName}}{ uint32(v % uint64(q0)) } + z.toMont() return z {{- else }} z := {{.ElementName}}{ v } @@ -97,15 +97,14 @@ func New{{.ElementName}}(v uint64) {{.ElementName}} { // SetUint64 sets z to v and returns z func (z *{{.ElementName}}) SetUint64(v uint64) *{{.ElementName}} { -{{- if eq .Word.BitSize 32}} - // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form - *z = {{.ElementName}}{ {{$.Word.TypeLower}}(v % uint64(q0)) } - return z.Mul(z, &rSquare) // z.toMont() -{{- else }} // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form - *z = {{.ElementName}}{ v } - return z.Mul(z, &rSquare) // z.toMont() -{{- end}} + {{- if .F31}} + *z = {{.ElementName}}{ uint32(v % uint64(q0)) } + return z.toMont() + {{- else }} + *z = {{.ElementName}}{ v } + return z.Mul(z, &rSquare) // z.toMont() + {{- end}} } // SetInt64 sets z to v and returns z @@ -417,8 +416,8 @@ func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { {{- if eq .NbWords 1}} {{ $hasCarry := (not $.NoCarry)}} {{- if $hasCarry}} - var carry {{$.Word.TypeLower}} - z[0], carry = bits.{{$.Word.Add}}(x[0], y[0], 0) + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) if carry != 0 || z[0] >= q { z[0] -= q } @@ -568,6 +567,7 @@ func (z *{{.ElementName}}) Select(c int, x0 *{{.ElementName}}, x1 *{{.ElementNam return z } +{{- if ne .NbWords 1}} // _mulGeneric is unoptimized textbook CIOS // it is a fallback solution on x86 when ADX instruction set is not available // and is used for testing purposes. @@ -576,25 +576,30 @@ func _mulGeneric(z,x,y *{{.ElementName}}) { {{ template "mul_cios" dict "all" . "V1" "x" "V2" "y"}} {{ template "reduce" . 
}} } +{{- end}} func _fromMontGeneric(z *{{.ElementName}}) { - // the following lines implement z = z * 1 - // with a modified CIOS montgomery multiplication - // see Mul for algorithm documentation - {{- range $j := .NbWordsIndexesFull}} - { - // m = z[0]n'[0] mod W - m := z[0] * qInvNeg - C := madd0(m, q0, z[0]) - {{- range $i := $.NbWordsIndexesNoZero}} - C, z[{{sub $i 1}}] = madd2(m, q{{$i}}, z[{{$i}}], C) + {{- if .F31}} + z[0] = montReduce(uint64(z[0])) + {{- else}} + // the following lines implement z = z * 1 + // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation + {{- range $j := .NbWordsIndexesFull}} + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + {{- range $i := $.NbWordsIndexesNoZero}} + C, z[{{sub $i 1}}] = madd2(m, q{{$i}}, z[{{$i}}], C) + {{- end}} + z[{{sub $.NbWords 1}}] = C + } {{- end}} - z[{{sub $.NbWords 1}}] = C - } - {{- end}} - {{ template "reduce" .}} + {{ template "reduce" .}} + {{- end}} } func _reduceGeneric(z *{{.ElementName}}) { diff --git a/field/generator/internal/templates/element/conv.go b/field/generator/internal/templates/element/conv.go index 8b021d4a69..1587bdb7fb 100644 --- a/field/generator/internal/templates/element/conv.go +++ b/field/generator/internal/templates/element/conv.go @@ -13,7 +13,13 @@ var rSquare = {{.ElementName}}{ // toMont converts z to Montgomery form // sets and returns z = z * r² func (z *{{.ElementName}}) toMont() *{{.ElementName}} { - return z.Mul(z, &rSquare) + {{- if .F31}} + const rBits = 32 + z[0] = uint32((uint64(z[0]) << rBits) % q) + return z + {{- else}} + return z.Mul(z, &rSquare) + {{- end}} } // String returns the decimal representation of z as generated by @@ -205,23 +211,23 @@ func (z *{{.ElementName}}) SetBigInt(v *big.Int) *{{.ElementName}} { func (z *{{.ElementName}}) setBigInt(v *big.Int) *{{.ElementName}} { vBits := v.Bits() - {{- if eq .Word.BitSize 32}} - // we assume v < q, so even if big.Int words are on 64bits, we can safely cast them to 32bits - for i := 0; i < len(vBits); i++ { - z[i] = {{$.Word.TypeLower}}(vBits[i]) - } + {{- if .F31}} + // we assume v < q, so even if big.Int words are on 64bits, we can safely cast them to 32bits + for i := 0; i < len(vBits); i++ { + z[i] = uint32(vBits[i]) + } {{- else}} if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { - z[i] = {{$.Word.TypeLower}}(vBits[i]) + z[i] = uint64(vBits[i]) } } else { for i := 0; i < len(vBits); i++ { if i%2 == 0 { - z[i/2] = {{$.Word.TypeLower}}(vBits[i]) + z[i/2] = uint64(vBits[i]) } else { - z[i/2] |= {{$.Word.TypeLower}}(vBits[i]) << 32 + z[i/2] |= uint64(vBits[i]) << 32 } } } diff --git a/field/generator/internal/templates/element/mul_cios.go b/field/generator/internal/templates/element/mul_cios.go index 965024efda..17c79cac69 100644 --- a/field/generator/internal/templates/element/mul_cios.go +++ b/field/generator/internal/templates/element/mul_cios.go @@ -47,7 +47,7 @@ const MulCIOS = ` // ----------------------------------- // First loop {{ if eq $j 0}} - C, t[0] = bits.{{$.all.Word.Mul}}({{$.V2}}[{{$j}}], {{$.V1}}[0]) + C, t[0] = bits.Mul64({{$.V2}}[{{$j}}], {{$.V1}}[0]) {{- range $i := $.all.NbWordsIndexesNoZero}} C, t[{{$i}}] = madd1({{$.V2}}[{{$j}}], {{$.V1}}[{{$i}}], C) {{- end}} @@ -110,12 +110,12 @@ const MulCIOS = ` // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. 
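For F31 fields the generator now routes Mul through montReduce on the full 62-bit product instead of this one-limb template. As a sanity sketch that the two routes agree, the following compares them on koalabear's constants; both functions are local illustrations, not the generated code.

package main

import (
	"fmt"
	"math/bits"
)

const (
	q       uint32 = 2130706433 // koalabear modulus, for illustration
	qInvNeg uint32 = 2130706431 // -q⁻¹ mod 2³²
)

// redc32 is the retired one-limb route: REDC with the (lo != 0) carry trick.
func redc32(x, y uint32) uint32 {
	hi, lo := bits.Mul32(x, y)
	if lo != 0 {
		hi++
	}
	m := lo * qInvNeg
	hi2, _ := bits.Mul32(m, q)
	r := hi + hi2
	if r >= q {
		r -= q
	}
	return r
}

// mulF31 is the new route: one widening multiply, then a full Montgomery reduction.
func mulF31(x, y uint32) uint32 {
	v := uint64(x) * uint64(y)
	m := (v * uint64(qInvNeg)) % (1 << 32)
	t := uint32((v + m*uint64(q)) >> 32)
	if t >= q {
		t -= q
	}
	return t
}

func main() {
	for _, p := range [][2]uint32{{1, 1}, {3, 5}, {q - 1, q - 2}} {
		fmt.Println(redc32(p[0], p[1]) == mulF31(p[0], p[1])) // always true
	}
}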
- hi, lo := bits.{{$.all.Word.Mul}}({{$.V1}}[0], {{$.V2}}[0]) + hi, lo := bits.Mul64({{$.V1}}[0], {{$.V2}}[0]) if lo != 0 { hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow } m := lo * qInvNeg - hi2, _ := bits.{{$.all.Word.Mul}}(m, q) + hi2, _ := bits.Mul64(m, q) {{ $hasCarry := (not $.all.NoCarry)}} {{- if $hasCarry}} r, carry := bits.{{$.all.Word.Add}}(hi2, hi, 0) diff --git a/field/generator/internal/templates/element/ops_purego.go b/field/generator/internal/templates/element/ops_purego.go index 498b7e1ae6..d35aa32a06 100644 --- a/field/generator/internal/templates/element/ops_purego.go +++ b/field/generator/internal/templates/element/ops_purego.go @@ -2,7 +2,9 @@ package element const OpsNoAsm = ` +{{- if not $.F31}} import "math/bits" +{{- end}} {{ $mulConsts := list 3 5 13 }} {{- range $i := $mulConsts }} @@ -42,6 +44,19 @@ func reduce(z *{{.ElementName}}) { _reduceGeneric(z) } +{{- if $.F31}} +func montReduce(v uint64) uint32 { + const rBits = 32 + const r = 1 << rBits + m := (v * qInvNeg ) % r + t := uint32((v + m * q) >> rBits) + if t >= q { + t -= q + } + return t +} +{{- end}} + // Mul z = x * y (mod q) {{- if $.NoCarry}} // @@ -49,7 +64,12 @@ func reduce(z *{{.ElementName}}) { {{- end }} func (z *{{.ElementName}}) Mul(x, y *{{.ElementName}}) *{{.ElementName}} { {{- if eq $.NbWords 1}} - {{ template "mul_cios_one_limb" dict "all" . "V1" "x" "V2" "y" }} + {{- if $.F31}} + v := uint64(x[0]) * uint64(y[0]) + z[0] = montReduce(v) + {{- else}} + {{ template "mul_cios_one_limb" dict "all" . "V1" "x" "V2" "y" }} + {{- end}} {{- else }} {{ mul_doc $.NoCarry }} {{- if $.NoCarry}} @@ -70,7 +90,12 @@ func (z *{{.ElementName}}) Mul(x, y *{{.ElementName}}) *{{.ElementName}} { func (z *{{.ElementName}}) Square(x *{{.ElementName}}) *{{.ElementName}} { // see Mul for algorithm documentation {{- if eq $.NbWords 1}} - {{ template "mul_cios_one_limb" dict "all" . "V1" "x" "V2" "x" }} + {{- if $.F31}} + v := uint64(x[0]) * uint64(x[0]) + z[0] = montReduce(v) + {{- else}} + {{ template "mul_cios_one_limb" dict "all" . "V1" "x" "V2" "x" }} + {{- end}} {{- else }} {{- if $.NoCarry}} {{ template "mul_nocarry" dict "all" . "V1" "x" "V2" "x"}} diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index 7008294565..bdcc83956c 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -730,7 +730,11 @@ func Test{{toTitle .ElementName}}LexicographicallyLargest(t *testing.T) { {{template "testBinaryOp" dict "all" . "Op" "Add"}} {{template "testBinaryOp" dict "all" . "Op" "Sub"}} +{{- if ne .NbWords 1}} {{template "testBinaryOp" dict "all" . "Op" "Mul" "GenericOp" "_mulGeneric"}} +{{- else}} +{{template "testBinaryOp" dict "all" . "Op" "Mul"}} +{{- end}} {{template "testBinaryOp" dict "all" . "Op" "Div"}} {{template "testBinaryOp" dict "all" . 
"Op" "Exp"}} diff --git a/field/goldilocks/element.go b/field/goldilocks/element.go index e48950ae15..207ee7fbb5 100644 --- a/field/goldilocks/element.go +++ b/field/goldilocks/element.go @@ -54,8 +54,8 @@ const ( // Field modulus q const ( - q0 uint64 = 18446744069414584321 - q uint64 = q0 + q0 = 18446744069414584321 + q = q0 ) var qElement = Element{ @@ -74,7 +74,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 18446744069414584319 +const qInvNeg = 18446744069414584319 func init() { _modulus.SetString("ffffffff00000001", 16) @@ -407,49 +407,6 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } -// _mulGeneric is unoptimized textbook CIOS -// it is a fallback solution on x86 when ADX instruction set is not available -// and is used for testing purposes. -func _mulGeneric(z, x, y *Element) { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t [2]uint64 - var D uint64 - var m, C uint64 - // ----------------------------------- - // First loop - - C, t[0] = bits.Mul64(y[0], x[0]) - - t[1], D = bits.Add64(t[1], C, 0) - - // m = t[0]n'[0] mod W - m = t[0] * qInvNeg - - // ----------------------------------- - // Second loop - C = madd0(m, q0, t[0]) - - t[0], C = bits.Add64(t[1], C, 0) - t[1], _ = bits.Add64(0, D, C) - - if t[1] != 0 { - // we need to reduce, we have a result on 2 words - z[0], _ = bits.Sub64(t[0], q0, 0) - return - } - - // copy t into z - z[0] = t[0] - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - z[0] -= q - } -} - func _fromMontGeneric(z *Element) { // the following lines implement z = z * 1 // with a modified CIOS montgomery multiplication diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index 454d057dbf..52fa5e2691 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -932,14 +932,6 @@ func TestElementMul(t *testing.T) { c.Mul(&a.element, &r) d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) - // checking generic impl against asm path - var cGeneric Element - _mulGeneric(&cGeneric, &a.element, &r) - if !cGeneric.Equal(&c) { - // need to give context to failing error. 
- return false - } - if c.BigInt(&e).Cmp(&d) != 0 { return false } @@ -962,17 +954,6 @@ func TestElementMul(t *testing.T) { genB, )) - properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - c.Mul(&a.element, &b.element) - _mulGeneric(&d, &a.element, &b.element) - return c.Equal(&d) - }, - genA, - genB, - )) - specialValueTest := func() { // test special values against special values testValues := make([]Element, len(staticTestValues)) @@ -991,13 +972,6 @@ func TestElementMul(t *testing.T) { c.Mul(&a, &b) d.Mul(&aBig, &bBig).Mod(&d, Modulus()) - // checking asm against generic impl - var cGeneric Element - _mulGeneric(&cGeneric, &a, &b) - if !cGeneric.Equal(&c) { - t.Fatal("Mul failed special test values: asm and generic impl don't match") - } - if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } diff --git a/field/koalabear/arith.go b/field/koalabear/arith.go index aab49151f6..0252a5bd1e 100644 --- a/field/koalabear/arith.go +++ b/field/koalabear/arith.go @@ -21,40 +21,40 @@ import ( ) // madd0 hi = a*b + c (discards lo bits) -func madd0(a, b, c uint32) (hi uint32) { - var carry, lo uint32 - hi, lo = bits.Mul32(a, b) - _, carry = bits.Add32(lo, c, 0) - hi, _ = bits.Add32(hi, 0, carry) +func madd0(a, b, c uint64) (hi uint64) { + var carry, lo uint64 + hi, lo = bits.Mul64(a, b) + _, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) return } // madd1 hi, lo = a*b + c -func madd1(a, b, c uint32) (hi uint32, lo uint32) { - var carry uint32 - hi, lo = bits.Mul32(a, b) - lo, carry = bits.Add32(lo, c, 0) - hi, _ = bits.Add32(hi, 0, carry) +func madd1(a, b, c uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) return } // madd2 hi, lo = a*b + c + d -func madd2(a, b, c, d uint32) (hi uint32, lo uint32) { - var carry uint32 - hi, lo = bits.Mul32(a, b) - c, carry = bits.Add32(c, d, 0) - hi, _ = bits.Add32(hi, 0, carry) - lo, carry = bits.Add32(lo, c, 0) - hi, _ = bits.Add32(hi, 0, carry) +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) return } -func madd3(a, b, c, d, e uint32) (hi uint32, lo uint32) { - var carry uint32 - hi, lo = bits.Mul32(a, b) - c, carry = bits.Add32(c, d, 0) - hi, _ = bits.Add32(hi, 0, carry) - lo, carry = bits.Add32(lo, c, 0) - hi, _ = bits.Add32(hi, e, carry) +func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, e, carry) return } diff --git a/field/koalabear/element.go b/field/koalabear/element.go index 645d192926..48a74656e3 100644 --- a/field/koalabear/element.go +++ b/field/koalabear/element.go @@ -54,8 +54,8 @@ const ( // Field modulus q const ( - q0 uint32 = 2130706433 - q uint32 = q0 + q0 = 2130706433 + q = q0 ) var qElement = Element{ @@ -74,7 +74,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint32 = 2130706431 +const qInvNeg = 2130706431 func init() { _modulus.SetString("7f000001", 16) @@ -88,7 +88,7 @@ func init() { // v.SetUint64(...) 
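koalabear picks up the same constructor change as babybear above. Since this file also received the one-word Inverse earlier in the series, here is that binary extended Euclidean loop (Algorithm 16 of the cited paper) stripped of the Montgomery bookkeeping: s starts at 1 instead of r², so this sketch computes a plain modular inverse rather than the generated, Montgomery-form one.

package main

import "fmt"

const q uint32 = 2130706433 // koalabear modulus, for illustration

// inverse returns x⁻¹ mod q for 0 < x < q. Halving stays exact because q is
// odd: an odd r is replaced by (r+q)/2, computed in 64 bits to sidestep the
// uint32 overflow the generated code handles via bits.Add32 and its carry.
func inverse(x uint32) uint32 {
	u, v := q, x
	var r, s uint32 = 0, 1
	for u != 1 && v != 1 {
		for v&1 == 0 {
			v >>= 1
			if s&1 == 0 {
				s >>= 1
			} else {
				s = uint32((uint64(s) + uint64(q)) >> 1)
			}
		}
		for u&1 == 0 {
			u >>= 1
			if r&1 == 0 {
				r >>= 1
			} else {
				r = uint32((uint64(r) + uint64(q)) >> 1)
			}
		}
		if v >= u {
			v -= u
			if s < r {
				s += q
			}
			s -= r
		} else {
			u -= v
			if r < s {
				r += q
			}
			r -= s
		}
	}
	if u == 1 {
		return r
	}
	return s
}

func main() {
	x := uint32(12345)
	fmt.Println(uint64(x) * uint64(inverse(x)) % uint64(q)) // prints 1
}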
func NewElement(v uint64) Element { z := Element{uint32(v % uint64(q0))} - z.Mul(&z, &rSquare) + z.toMont() return z } @@ -96,7 +96,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{uint32(v % uint64(q0))} - return z.Mul(z, &rSquare) // z.toMont() + return z.toMont() } // SetInt64 sets z to v and returns z @@ -393,64 +393,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } -// _mulGeneric is unoptimized textbook CIOS -// it is a fallback solution on x86 when ADX instruction set is not available -// and is used for testing purposes. -func _mulGeneric(z, x, y *Element) { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t [2]uint32 - var D uint32 - var m, C uint32 - // ----------------------------------- - // First loop - - C, t[0] = bits.Mul32(y[0], x[0]) - - t[1], D = bits.Add32(t[1], C, 0) - - // m = t[0]n'[0] mod W - m = t[0] * qInvNeg - - // ----------------------------------- - // Second loop - C = madd0(m, q0, t[0]) - - t[0], C = bits.Add32(t[1], C, 0) - t[1], _ = bits.Add32(0, D, C) - - if t[1] != 0 { - // we need to reduce, we have a result on 2 words - z[0], _ = bits.Sub32(t[0], q0, 0) - return - } - - // copy t into z - z[0] = t[0] - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - z[0] -= q - } -} - func _fromMontGeneric(z *Element) { - // the following lines implement z = z * 1 - // with a modified CIOS montgomery multiplication - // see Mul for algorithm documentation - { - // m = z[0]n'[0] mod W - m := z[0] * qInvNeg - C := madd0(m, q0, z[0]) - z[0] = C - } - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - z[0] -= q - } + z[0] = montReduce(uint64(z[0])) } func _reduceGeneric(z *Element) { @@ -576,7 +520,9 @@ var rSquare = Element{ // toMont converts z to Montgomery form // sets and returns z = z * r² func (z *Element) toMont() *Element { - return z.Mul(z, &rSquare) + const rBits = 32 + z[0] = uint32((uint64(z[0]) << rBits) % q) + return z } // String returns the decimal representation of z as generated by diff --git a/field/koalabear/element_purego.go b/field/koalabear/element_purego.go index 628e2c4afb..243b8d13d3 100644 --- a/field/koalabear/element_purego.go +++ b/field/koalabear/element_purego.go @@ -16,8 +16,6 @@ package koalabear -import "math/bits" - // MulBy3 x *= 3 (mod q) func MulBy3(x *Element) { var y Element @@ -46,37 +44,23 @@ func fromMont(z *Element) { func reduce(z *Element) { _reduceGeneric(z) } +func montReduce(v uint64) uint32 { + const rBits = 32 + const r = 1 << rBits + m := (v * qInvNeg) % r + t := uint32((v + m*q) >> rBits) + if t >= q { + t -= q + } + return t +} // Mul z = x * y (mod q) // // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - - // In fact, since the modulus R fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction): - // hi, lo := x * y - // m := (lo * qInvNeg) mod R - // (*) r := (hi * R + lo + m * q) / R - // reduce r if necessary - - // On the emphasized line, we get r = hi + (lo + m * q) / R - // If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2 - // This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise. 
- // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) - // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. - - hi, lo := bits.Mul32(x[0], y[0]) - if lo != 0 { - hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow - } - m := lo * qInvNeg - hi2, _ := bits.Mul32(m, q) - - hi2 += hi - if hi2 >= q { - hi2 -= q - } - z[0] = hi2 - + v := uint64(x[0]) * uint64(y[0]) + z[0] = montReduce(v) return z } @@ -85,32 +69,8 @@ func (z *Element) Mul(x, y *Element) *Element { // x must be less than q func (z *Element) Square(x *Element) *Element { // see Mul for algorithm documentation - - // In fact, since the modulus R fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction): - // hi, lo := x * y - // m := (lo * qInvNeg) mod R - // (*) r := (hi * R + lo + m * q) / R - // reduce r if necessary - - // On the emphasized line, we get r = hi + (lo + m * q) / R - // If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2 - // This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise. - // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) - // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. - - hi, lo := bits.Mul32(x[0], x[0]) - if lo != 0 { - hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow - } - m := lo * qInvNeg - hi2, _ := bits.Mul32(m, q) - - hi2 += hi - if hi2 >= q { - hi2 -= q - } - z[0] = hi2 - + v := uint64(x[0]) * uint64(x[0]) + z[0] = montReduce(v) return z } diff --git a/field/koalabear/element_test.go b/field/koalabear/element_test.go index 13899466b6..ac6f3a8a70 100644 --- a/field/koalabear/element_test.go +++ b/field/koalabear/element_test.go @@ -932,14 +932,6 @@ func TestElementMul(t *testing.T) { c.Mul(&a.element, &r) d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) - // checking generic impl against asm path - var cGeneric Element - _mulGeneric(&cGeneric, &a.element, &r) - if !cGeneric.Equal(&c) { - // need to give context to failing error. 
- return false - } - if c.BigInt(&e).Cmp(&d) != 0 { return false } @@ -962,17 +954,6 @@ func TestElementMul(t *testing.T) { genB, )) - properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - c.Mul(&a.element, &b.element) - _mulGeneric(&d, &a.element, &b.element) - return c.Equal(&d) - }, - genA, - genB, - )) - specialValueTest := func() { // test special values against special values testValues := make([]Element, len(staticTestValues)) @@ -991,13 +972,6 @@ func TestElementMul(t *testing.T) { c.Mul(&a, &b) d.Mul(&aBig, &bBig).Mod(&d, Modulus()) - // checking asm against generic impl - var cGeneric Element - _mulGeneric(&cGeneric, &a, &b) - if !cGeneric.Equal(&c) { - t.Fatal("Mul failed special test values: asm and generic impl don't match") - } - if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } From db78c7b96921543e66e11d49e48c6dc53be4444e Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 6 Dec 2024 20:43:46 +0000 Subject: [PATCH 33/74] style: more cleaning --- .../internal/templates/element/mul_cios.go | 21 ++++++------------- field/goldilocks/element_purego.go | 2 -- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/field/generator/internal/templates/element/mul_cios.go b/field/generator/internal/templates/element/mul_cios.go index 17c79cac69..4a726df349 100644 --- a/field/generator/internal/templates/element/mul_cios.go +++ b/field/generator/internal/templates/element/mul_cios.go @@ -116,21 +116,12 @@ const MulCIOS = ` } m := lo * qInvNeg hi2, _ := bits.Mul64(m, q) - {{ $hasCarry := (not $.all.NoCarry)}} - {{- if $hasCarry}} - r, carry := bits.{{$.all.Word.Add}}(hi2, hi, 0) - if carry != 0 || r >= q { - // we need to reduce - r -= q - } - z[0] = r - {{- else}} - hi2 += hi - if hi2 >= q { - hi2 -= q - } - z[0] = hi2 - {{- end}} + r, carry := bits.{{$.all.Word.Add}}(hi2, hi, 0) + if carry != 0 || r >= q { + // we need to reduce + r -= q + } + z[0] = r {{ end }} ` diff --git a/field/goldilocks/element_purego.go b/field/goldilocks/element_purego.go index c7b2647b18..2dfbd7dbd3 100644 --- a/field/goldilocks/element_purego.go +++ b/field/goldilocks/element_purego.go @@ -68,7 +68,6 @@ func (z *Element) Mul(x, y *Element) *Element { } m := lo * qInvNeg hi2, _ := bits.Mul64(m, q) - r, carry := bits.Add64(hi2, hi, 0) if carry != 0 || r >= q { // we need to reduce @@ -101,7 +100,6 @@ func (z *Element) Square(x *Element) *Element { } m := lo * qInvNeg hi2, _ := bits.Mul64(m, q) - r, carry := bits.Add64(hi2, hi, 0) if carry != 0 || r >= q { // we need to reduce From 5c569d00faed26c93da302bbffe4f68e3fa5e676 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 6 Dec 2024 21:04:44 +0000 Subject: [PATCH 34/74] feat: cleaner mont mul, slower --- field/babybear/element_purego.go | 6 ++---- field/generator/internal/templates/element/ops_purego.go | 8 +++----- field/koalabear/element_purego.go | 6 ++---- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/field/babybear/element_purego.go b/field/babybear/element_purego.go index 38eb755db7..e618e83fd8 100644 --- a/field/babybear/element_purego.go +++ b/field/babybear/element_purego.go @@ -45,10 +45,8 @@ func reduce(z *Element) { _reduceGeneric(z) } func montReduce(v uint64) uint32 { - const rBits = 32 - const r = 1 << rBits - m := (v * qInvNeg) % r - t := uint32((v + m*q) >> rBits) + m := uint32(v) * qInvNeg + t := uint32((v + uint64(m)*q) >> 32) if t >= q { t -= q } diff --git 
a/field/generator/internal/templates/element/ops_purego.go b/field/generator/internal/templates/element/ops_purego.go index d35aa32a06..53f257a089 100644 --- a/field/generator/internal/templates/element/ops_purego.go +++ b/field/generator/internal/templates/element/ops_purego.go @@ -2,7 +2,7 @@ package element const OpsNoAsm = ` -{{- if not $.F31}} +{{- if not .F31}} import "math/bits" {{- end}} @@ -46,10 +46,8 @@ func reduce(z *{{.ElementName}}) { {{- if $.F31}} func montReduce(v uint64) uint32 { - const rBits = 32 - const r = 1 << rBits - m := (v * qInvNeg ) % r - t := uint32((v + m * q) >> rBits) + m := uint32(v) * qInvNeg + t := uint32((v + uint64(m) * q) >> 32) if t >= q { t -= q } diff --git a/field/koalabear/element_purego.go b/field/koalabear/element_purego.go index 243b8d13d3..137f5eb77d 100644 --- a/field/koalabear/element_purego.go +++ b/field/koalabear/element_purego.go @@ -45,10 +45,8 @@ func reduce(z *Element) { _reduceGeneric(z) } func montReduce(v uint64) uint32 { - const rBits = 32 - const r = 1 << rBits - m := (v * qInvNeg) % r - t := uint32((v + m*q) >> rBits) + m := uint32(v) * qInvNeg + t := uint32((v + uint64(m)*q) >> 32) if t >= q { t -= q } From 0a204127039e1214752ecba52763811f79938e9c Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 6 Dec 2024 21:06:20 +0000 Subject: [PATCH 35/74] fix integration test --- field/generator/config/field_config.go | 2 +- field/generator/generator.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 0813a3d46e..dfbb76cf79 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -114,7 +114,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) } // pre compute field constants F.NbBits = bModulus.BitLen() - F.F31 = F.NbBits == 31 + F.F31 = F.NbBits <= 31 F.NbWords = len(bModulus.Bits()) F.NbWordsLastIndex = F.NbWords - 1 diff --git a/field/generator/generator.go b/field/generator/generator.go index cca9c36721..7547db305a 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -144,7 +144,7 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude g.Go(generate("element.go", sourceFiles)) g.Go(generate("doc.go", []string{element.Doc})) g.Go(generate("vector.go", []string{element.Vector})) - g.Go(generate("arith.go", []string{element.Arith}, Only(F.F31))) + g.Go(generate("arith.go", []string{element.Arith}, Only(!F.F31))) g.Go(generate("element_test.go", testFiles)) g.Go(generate("vector_test.go", []string{element.TestVector})) From ac9720aeb428d2b83938863ec76b5cc381312e40 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 6 Dec 2024 21:14:55 +0000 Subject: [PATCH 36/74] style: more cleaning --- field/babybear/element.go | 2 +- .../internal/templates/element/base.go | 22 +++++++++---------- .../internal/templates/element/mul_cios.go | 22 +++++++++---------- field/koalabear/element.go | 2 +- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/field/babybear/element.go b/field/babybear/element.go index 63986904c7..2cc7891c9b 100644 --- a/field/babybear/element.go +++ b/field/babybear/element.go @@ -358,7 +358,7 @@ func (z *Element) Add(x, y *Element) *Element { // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { - z[0] = (x[0] << 1) + z[0] = x[0] << 1 if z[0] >= q { z[0] -= q } diff --git a/field/generator/internal/templates/element/base.go 
b/field/generator/internal/templates/element/base.go index ab4c20d24a..580fb580ff 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -84,15 +84,15 @@ func init() { // var v {{.ElementName}} // v.SetUint64(...) func New{{.ElementName}}(v uint64) {{.ElementName}} { -{{- if eq .Word.BitSize 32}} - z := {{.ElementName}}{ uint32(v % uint64(q0)) } - z.toMont() - return z -{{- else }} - z := {{.ElementName}}{ v } - z.Mul(&z, &rSquare) - return z -{{- end}} + {{- if .F31}} + z := {{.ElementName}}{ uint32(v % uint64(q0)) } + z.toMont() + return z + {{- else }} + z := {{.ElementName}}{ v } + z.Mul(&z, &rSquare) + return z + {{- end}} } // SetUint64 sets z to v and returns z @@ -460,8 +460,8 @@ func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { // Double z = x + x (mod q), aka Lsh 1 func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} { {{- if eq .NbWords 1}} - {{- if lt .NbBits 32}} - z[0] = (x[0] << 1) + {{- if .F31}} + z[0] = x[0] << 1 if z[0] >= q { z[0] -= q } diff --git a/field/generator/internal/templates/element/mul_cios.go b/field/generator/internal/templates/element/mul_cios.go index 4a726df349..e68992da0f 100644 --- a/field/generator/internal/templates/element/mul_cios.go +++ b/field/generator/internal/templates/element/mul_cios.go @@ -39,9 +39,9 @@ package element // The same (arm64) unrolled Go code produce satisfying performance for WASM (compiled using TinyGo). const MulCIOS = ` {{ define "mul_cios" }} - var t [{{add .all.NbWords 1}}]{{$.all.Word.TypeLower}} - var D {{$.all.Word.TypeLower}} - var m, C {{$.all.Word.TypeLower}} + var t [{{add .all.NbWords 1}}]uint64 + var D uint64 + var m, C uint64 {{- range $j := .all.NbWordsIndexesFull}} // ----------------------------------- @@ -57,7 +57,7 @@ const MulCIOS = ` C, t[{{$i}}] = madd2({{$.V2}}[{{$j}}], {{$.V1}}[{{$i}}], t[{{$i}}], C) {{- end}} {{ end }} - t[{{$.all.NbWords}}], D = bits.{{$.all.Word.Add}}(t[{{$.all.NbWords}}], C, 0) + t[{{$.all.NbWords}}], D = bits.Add64(t[{{$.all.NbWords}}], C, 0) // m = t[0]n'[0] mod W m = t[0] * qInvNeg @@ -69,22 +69,22 @@ const MulCIOS = ` C, t[{{sub $i 1}}] = madd2(m, q{{$i}}, t[{{$i}}], C) {{- end}} - t[{{sub $.all.NbWords 1}}], C = bits.{{$.all.Word.Add}}(t[{{$.all.NbWords}}], C, 0) - t[{{$.all.NbWords}}], _ = bits.{{$.all.Word.Add}}(0, D, C) + t[{{sub $.all.NbWords 1}}], C = bits.Add64(t[{{$.all.NbWords}}], C, 0) + t[{{$.all.NbWords}}], _ = bits.Add64(0, D, C) {{- end}} if t[{{$.all.NbWords}}] != 0 { // we need to reduce, we have a result on {{add 1 $.all.NbWords}} words {{- if gt $.all.NbWords 1}} - var b {{$.all.Word.TypeLower}} + var b uint64 {{- end}} - z[0], {{- if gt $.all.NbWords 1}}b{{- else}}_{{- end}} = bits.{{$.all.Word.Sub}}(t[0], q0, 0) + z[0], {{- if gt $.all.NbWords 1}}b{{- else}}_{{- end}} = bits.Sub64(t[0], q0, 0) {{- range $i := .all.NbWordsIndexesNoZero}} {{- if eq $i $.all.NbWordsLastIndex}} - z[{{$i}}], _ = bits.{{$.all.Word.Sub}}(t[{{$i}}], q{{$i}}, b) + z[{{$i}}], _ = bits.Sub64(t[{{$i}}], q{{$i}}, b) {{- else }} - z[{{$i}}], b = bits.{{$.all.Word.Sub}}(t[{{$i}}], q{{$i}}, b) + z[{{$i}}], b = bits.Sub64(t[{{$i}}], q{{$i}}, b) {{- end}} {{- end}} return {{if $.ReturnZ }} z{{- end}} @@ -116,7 +116,7 @@ const MulCIOS = ` } m := lo * qInvNeg hi2, _ := bits.Mul64(m, q) - r, carry := bits.{{$.all.Word.Add}}(hi2, hi, 0) + r, carry := bits.Add64(hi2, hi, 0) if carry != 0 || r >= q { // we need to reduce r -= q diff --git a/field/koalabear/element.go 
b/field/koalabear/element.go index 48a74656e3..8dfde59c9c 100644 --- a/field/koalabear/element.go +++ b/field/koalabear/element.go @@ -358,7 +358,7 @@ func (z *Element) Add(x, y *Element) *Element { // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { - z[0] = (x[0] << 1) + z[0] = x[0] << 1 if z[0] >= q { z[0] -= q } From 725a47680f192790ee2276fb7dc87324e7b25ba6 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 6 Dec 2024 21:17:12 +0000 Subject: [PATCH 37/74] test: fix failing generator test --- field/generator/config/field_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/field/generator/config/field_test.go b/field/generator/config/field_test.go index 27cc112e6f..e6916c992c 100644 --- a/field/generator/config/field_test.go +++ b/field/generator/config/field_test.go @@ -43,7 +43,12 @@ func TestIntToMont(t *testing.T) { func(f *FieldConfig) (bool, error) { // test if using the same R i := big.NewInt(1) - i.Lsh(i, uint(f.Word.BitSize)*uint(f.NbWords)) + if f.F31 { + i.Lsh(i, 31*uint(f.NbWords)) + } else { + i.Lsh(i, 64*uint(f.NbWords)) + } + *i = f.ToMont(*i) err := bigIntMatchUint64Slice(i, f.RSquare) From 5b11602ad4cdaf1caa8cc3934c735d3a738aabb2 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 6 Dec 2024 21:41:30 +0000 Subject: [PATCH 38/74] test: fix field config test to use bitsize --- field/generator/config/field_config.go | 9 ++++----- field/generator/config/field_test.go | 14 +++++--------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index dfbb76cf79..2d992db463 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -134,7 +134,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) F.Word.Add = "Add64" F.Word.Sub = "Sub64" F.Word.Len = "Len64" - if F.NbBits < 32 { + if F.F31 { F.Word.BitSize = 32 F.Word.ByteSize = 4 F.Word.TypeLower = "uint32" @@ -179,9 +179,8 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) } // rsquare - _rSquare := big.NewInt(2) - exponent := big.NewInt(int64(F.NbWords) * int64(radix) * 2) - _rSquare.Exp(_rSquare, exponent, &bModulus) + _rSquare := big.NewInt(1) + _rSquare.Lsh(_rSquare, uint(F.NbWords)*radix*2).Mod(_rSquare, &bModulus) F.RSquare = toUint64Slice(_rSquare, F.NbWords) var one big.Int @@ -389,7 +388,7 @@ func (f *FieldConfig) FromMont(nonMont *big.Int, mont *big.Int) *FieldConfig { return f } f.halve(nonMont, mont) - for i := 1; i < f.NbWords*64; i++ { + for i := 1; i < f.NbWords*f.Word.BitSize; i++ { f.halve(nonMont, nonMont) } diff --git a/field/generator/config/field_test.go b/field/generator/config/field_test.go index e6916c992c..6b668b60cd 100644 --- a/field/generator/config/field_test.go +++ b/field/generator/config/field_test.go @@ -20,7 +20,7 @@ func TestIntToMont(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() - parameters.MinSuccessfulTests = 10 + parameters.MinSuccessfulTests = 20 properties := gopter.NewProperties(parameters) genF := genField(t) @@ -42,16 +42,12 @@ func TestIntToMont(t *testing.T) { properties.Property("turning R into montgomery form must match the R value from field", prop.ForAll( func(f *FieldConfig) (bool, error) { // test if using the same R - i := big.NewInt(1) - if f.F31 { - i.Lsh(i, 31*uint(f.NbWords)) - } else { - i.Lsh(i, 64*uint(f.NbWords)) - } + r := big.NewInt(1) + r.Lsh(r, 
uint(f.Word.BitSize)*uint(f.NbWords)) - *i = f.ToMont(*i) + *r = f.ToMont(*r) - err := bigIntMatchUint64Slice(i, f.RSquare) + err := bigIntMatchUint64Slice(r, f.RSquare) return err == nil, err }, genF), ) From 8b3b4d165f32c911c446f72c9c9ea1c855921e6a Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 6 Dec 2024 21:51:29 +0000 Subject: [PATCH 39/74] feat: on 31bit field better branch-less add and sub --- field/babybear/element.go | 20 ++--- .../internal/templates/element/base.go | 78 +++++++++++-------- field/koalabear/element.go | 20 ++--- 3 files changed, 67 insertions(+), 51 deletions(-) diff --git a/field/babybear/element.go b/field/babybear/element.go index 2cc7891c9b..a9f524fa84 100644 --- a/field/babybear/element.go +++ b/field/babybear/element.go @@ -349,29 +349,31 @@ func (z *Element) fromMont() *Element { // Add z = x + y (mod q) func (z *Element) Add(x, y *Element) *Element { - z[0] = x[0] + y[0] - if z[0] >= q { - z[0] -= q + t := x[0] + y[0] + if t >= q { + t -= q } + z[0] = t return z } // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { - z[0] = x[0] << 1 - if z[0] >= q { - z[0] -= q + t := x[0] << 1 + if t >= q { + t -= q } + z[0] = t return z } // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { - var b uint32 - z[0], b = bits.Sub32(x[0], y[0], 0) + t, b := bits.Sub32(x[0], y[0], 0) if b != 0 { - z[0] += q + t += q } + z[0] = t return z } diff --git a/field/generator/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go index 580fb580ff..9965edbd52 100644 --- a/field/generator/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -423,10 +423,11 @@ func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { } return z {{- else}} - z[0] = x[0] + y[0] - if z[0] >= q { - z[0] -= q + t := x[0] + y[0] + if t >= q { + t -= q } + z[0] = t return z {{- end}} {{- else}} @@ -461,23 +462,25 @@ func (z *{{.ElementName}}) Add( x, y *{{.ElementName}}) *{{.ElementName}} { func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} { {{- if eq .NbWords 1}} {{- if .F31}} - z[0] = x[0] << 1 - if z[0] >= q { - z[0] -= q + t := x[0] << 1 + if t >= q { + t -= q } + z[0] = t + return z {{- else}} if x[0]&(1<<63) == (1 << 63) { - // if highest bit is set, then we have a carry to x + x, we shift and subtract q - z[0] = (x[0] << 1) - q + // if highest bit is set, then we have a carry to x + x, we shift and subtract q + z[0] = (x[0] << 1) - q } else { - // highest bit is not set, but x + x can still be >= q - z[0] = (x[0] << 1) - if z[0] >= q { - z[0] -= q - } + // highest bit is not set, but x + x can still be >= q + z[0] = (x[0] << 1) + if z[0] >= q { + z[0] -= q + } } + return z {{- end}} - return z {{- else}} {{ $hasCarry := or (not $.NoCarry) (gt $.NbWords 1)}} {{- if $hasCarry}} @@ -510,27 +513,36 @@ func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} { // Sub z = x - y (mod q) func (z *{{.ElementName}}) Sub( x, y *{{.ElementName}}) *{{.ElementName}} { - var b {{$.Word.TypeLower}} - z[0], b = bits.{{$.Word.Sub}}(x[0], y[0], 0) - {{- range $i := .NbWordsIndexesNoZero}} - z[{{$i}}], b = bits.{{$.Word.Sub}}(x[{{$i}}], y[{{$i}}], b) - {{- end}} - if b != 0 { - {{- if eq .NbWords 1}} - z[0] += q - {{- else}} - var c uint64 - z[0], c = bits.Add64(z[0], q0, 0) - {{- range $i := .NbWordsIndexesNoZero}} - {{- if eq $i $.NbWordsLastIndex}} - z[{{$i}}], _ = bits.Add64(z[{{$i}}], q{{$i}}, c) - {{- else}} - z[{{$i}}], c = 
bits.Add64(z[{{$i}}], q{{$i}}, c) + {{- if $.F31}} + t, b := bits.Sub32(x[0], y[0], 0) + if b != 0 { + t += q + } + z[0] = t + return z + {{- else}} + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + {{- range $i := .NbWordsIndexesNoZero}} + z[{{$i}}], b = bits.Sub64(x[{{$i}}], y[{{$i}}], b) + {{- end}} + if b != 0 { + {{- if eq .NbWords 1}} + z[0] += q + {{- else}} + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + {{- range $i := .NbWordsIndexesNoZero}} + {{- if eq $i $.NbWordsLastIndex}} + z[{{$i}}], _ = bits.Add64(z[{{$i}}], q{{$i}}, c) + {{- else}} + z[{{$i}}], c = bits.Add64(z[{{$i}}], q{{$i}}, c) + {{- end}} {{- end}} {{- end}} - {{- end}} - } - return z + } + return z + {{- end}} } diff --git a/field/koalabear/element.go b/field/koalabear/element.go index 8dfde59c9c..280a57d50f 100644 --- a/field/koalabear/element.go +++ b/field/koalabear/element.go @@ -349,29 +349,31 @@ func (z *Element) fromMont() *Element { // Add z = x + y (mod q) func (z *Element) Add(x, y *Element) *Element { - z[0] = x[0] + y[0] - if z[0] >= q { - z[0] -= q + t := x[0] + y[0] + if t >= q { + t -= q } + z[0] = t return z } // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { - z[0] = x[0] << 1 - if z[0] >= q { - z[0] -= q + t := x[0] << 1 + if t >= q { + t -= q } + z[0] = t return z } // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { - var b uint32 - z[0], b = bits.Sub32(x[0], y[0], 0) + t, b := bits.Sub32(x[0], y[0], 0) if b != 0 { - z[0] += q + t += q } + z[0] = t return z } From d022a641817f91f7b7a9967ec12d4561fcbafc84 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sat, 7 Dec 2024 23:36:33 +0000 Subject: [PATCH 40/74] feat: skeletton for vec assembly on F31 --- field/asm/element_1w_amd64.s | 6 ++ field/babybear/asm_adx.go | 26 ++++++++ field/babybear/asm_avx.go | 26 ++++++++ field/babybear/asm_noadx.go | 27 ++++++++ field/babybear/asm_noavx.go | 21 ++++++ field/babybear/element_amd64.go | 66 +++++++++++++++++++ field/babybear/element_amd64.s | 21 ++++++ field/babybear/vector_amd64.go | 56 ++++++++++++++++ field/babybear/vector_purego.go | 2 + field/generator/asm/amd64/build.go | 9 +++ field/generator/config/field_config.go | 4 +- field/generator/generator.go | 15 +++-- .../internal/templates/element/asm.go | 2 +- .../templates/element/vector_ops_asm.go | 39 +++++++++++ field/internal/main.go | 9 ++- field/koalabear/asm_avx.go | 26 ++++++++ field/koalabear/asm_noavx.go | 21 ++++++ field/koalabear/element_amd64.s | 21 ++++++ field/koalabear/vector_amd64.go | 56 ++++++++++++++++ field/koalabear/vector_purego.go | 2 + 20 files changed, 447 insertions(+), 8 deletions(-) create mode 100644 field/asm/element_1w_amd64.s create mode 100644 field/babybear/asm_adx.go create mode 100644 field/babybear/asm_avx.go create mode 100644 field/babybear/asm_noadx.go create mode 100644 field/babybear/asm_noavx.go create mode 100644 field/babybear/element_amd64.go create mode 100644 field/babybear/element_amd64.s create mode 100644 field/babybear/vector_amd64.go create mode 100644 field/koalabear/asm_avx.go create mode 100644 field/koalabear/asm_noavx.go create mode 100644 field/koalabear/element_amd64.s create mode 100644 field/koalabear/vector_amd64.go diff --git a/field/asm/element_1w_amd64.s b/field/asm/element_1w_amd64.s new file mode 100644 index 0000000000..10ae161550 --- /dev/null +++ b/field/asm/element_1w_amd64.s @@ -0,0 +1,6 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. 
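While the assembly file above is still a stub, the scalar Add/Sub shape introduced by the branch-less patch just before it can be property-checked on its own. A standalone sketch (not part of the patch; it reuses the koalabear modulus and plain uint32 words rather than the Element type):

package main

import (
	"fmt"
	"math/bits"
	"math/rand"
)

const q uint32 = 2130706433 // koalabear, as in the diffs above

// addMod follows the patched Add: one add, one conditional subtract.
func addMod(x, y uint32) uint32 {
	t := x + y // x, y < q < 2^31, so no 32-bit overflow
	if t >= q {
		t -= q
	}
	return t
}

// subMod follows the patched Sub: borrow-aware subtract, add q back on borrow.
func subMod(x, y uint32) uint32 {
	t, b := bits.Sub32(x, y, 0)
	if b != 0 {
		t += q
	}
	return t
}

func main() {
	for i := 0; i < 1_000_000; i++ {
		x, y := rand.Uint32()%q, rand.Uint32()%q
		if uint64(addMod(x, y)) != (uint64(x)+uint64(y))%uint64(q) {
			panic("add mismatch")
		}
		if uint64(subMod(x, y)) != (uint64(x)+uint64(q)-uint64(y))%uint64(q) {
			panic("sub mismatch")
		}
	}
	fmt.Println("ok")
}
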
+#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +// TODO: implement F31 assembly code diff --git a/field/babybear/asm_adx.go b/field/babybear/asm_adx.go new file mode 100644 index 0000000000..33faad5ec2 --- /dev/null +++ b/field/babybear/asm_adx.go @@ -0,0 +1,26 @@ +//go:build !noadx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +import "golang.org/x/sys/cpu" + +var ( + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx +) diff --git a/field/babybear/asm_avx.go b/field/babybear/asm_avx.go new file mode 100644 index 0000000000..c46b1dca6d --- /dev/null +++ b/field/babybear/asm_avx.go @@ -0,0 +1,26 @@ +//go:build !noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/field/babybear/asm_noadx.go b/field/babybear/asm_noadx.go new file mode 100644 index 0000000000..c01b8ba5dc --- /dev/null +++ b/field/babybear/asm_noadx.go @@ -0,0 +1,27 @@ +//go:build noadx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +// note: this is needed for test purposes, as dynamically changing supportAdx doesn't flag +// certain errors (like fatal error: missing stackmap) +// this ensures we test all asm path. +var ( + supportAdx = false + _ = supportAdx +) diff --git a/field/babybear/asm_noavx.go b/field/babybear/asm_noavx.go new file mode 100644 index 0000000000..1e4c0fd04b --- /dev/null +++ b/field/babybear/asm_noavx.go @@ -0,0 +1,21 @@ +//go:build noavx + +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +const supportAvx512 = false diff --git a/field/babybear/element_amd64.go b/field/babybear/element_amd64.go new file mode 100644 index 0000000000..9a0c54891f --- /dev/null +++ b/field/babybear/element_amd64.go @@ -0,0 +1,66 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// +//go:noescape +func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/field/babybear/element_amd64.s b/field/babybear/element_amd64.s new file mode 100644 index 0000000000..86531c7112 --- /dev/null +++ b/field/babybear/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
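The Butterfly declaration in element_amd64.go above only documents its contract. For reading purposes, a pure-Go model of those semantics, built on the package's own Add/Sub (a sketch, not the assembly and not code from this patch):

// butterflyGeneric models the documented contract of the assembly Butterfly:
// a, b = a+b (mod q), a-b (mod q)
func butterflyGeneric(a, b *Element) {
	t := *a // save the original a before it is overwritten
	a.Add(a, b)
	b.Sub(&t, b)
}
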
+ +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 16364262440698776471 +#include "../asm/element_1w_amd64.s" + diff --git a/field/babybear/vector_amd64.go b/field/babybear/vector_amd64.go new file mode 100644 index 0000000000..cd1b7e5fc6 --- /dev/null +++ b/field/babybear/vector_amd64.go @@ -0,0 +1,56 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package babybear + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/field/babybear/vector_purego.go b/field/babybear/vector_purego.go index 8843280543..fd55d9406b 100644 --- a/field/babybear/vector_purego.go +++ b/field/babybear/vector_purego.go @@ -1,3 +1,5 @@ +//go:build purego || !amd64 + // Copyright 2020 ConsenSys Software Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 40edcfc366..517b92bcce 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -275,6 +275,10 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { f.WriteLn("#include \"go_asm.h\"") f.WriteLn("") + if nbWords == 1 { + return GenerateF31ASM(f, hasVector) + } + f.GenerateReduceDefine() // reduce @@ -309,3 +313,8 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { return nil } + +func GenerateF31ASM(f *FFAmd64, hasVector bool) error { + f.Comment("TODO: implement F31 assembly code") + return nil +} diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 2d992db463..26ee5fe5e7 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -299,8 +299,8 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // note: to simplify output files generated, we generated ASM code only for // moduli that meet the condition F.NoCarry // asm code generation for moduli with more than 6 words can be optimized further - F.GenerateOpsAMD64 = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 - F.GenerateVectorOpsAMD64 = F.GenerateOpsAMD64 && F.NbWords == 4 && F.NbBits > 225 + F.GenerateOpsAMD64 = F.F31 || (F.NoCarry && F.NbWords <= 12 && F.NbWords > 1) + F.GenerateVectorOpsAMD64 = F.F31 || (F.GenerateOpsAMD64 && F.NbWords == 4 && F.NbBits > 225) F.GenerateOpsARM64 = F.GenerateOpsAMD64 && (F.NbWords%2 == 0) F.GenerateVectorOpsARM64 = false diff --git a/field/generator/generator.go b/field/generator/generator.go index 7547db305a..174c205c27 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -132,6 +132,7 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude } else if !F.GenerateOpsARM64 { pureGoBuildTag = "purego || (!amd64)" } + pureGoVectorBuildTag := "purego || (!amd64 && !arm64)" if !F.GenerateVectorOpsAMD64 && !F.GenerateVectorOpsARM64 { pureGoVectorBuildTag = "" @@ -139,6 +140,11 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude pureGoVectorBuildTag = "purego || (!amd64)" } + if F.F31 { + pureGoBuildTag = "" // always generate pure go for F31 + pureGoVectorBuildTag = "purego || (!amd64)" + } + var g errgroup.Group g.Go(generate("element.go", sourceFiles)) @@ -151,18 +157,19 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude g.Go(generate("element_amd64.s", []string{element.IncludeASM}, Only(F.GenerateOpsAMD64), WithBuildTag("!purego"), WithData(amd64d))) g.Go(generate("element_arm64.s", []string{element.IncludeASM}, Only(F.GenerateOpsARM64), WithBuildTag("!purego"), WithData(arm64d))) - g.Go(generate("element_amd64.go", []string{element.OpsAMD64, element.MulDoc}, Only(F.GenerateOpsAMD64), WithBuildTag("!purego"))) + g.Go(generate("element_amd64.go", []string{element.OpsAMD64, element.MulDoc}, Only(F.GenerateOpsAMD64 && !F.F31), WithBuildTag("!purego"))) g.Go(generate("element_arm64.go", []string{element.OpsARM64, element.MulNoCarry, element.Reduce}, Only(F.GenerateOpsARM64), WithBuildTag("!purego"))) g.Go(generate("element_purego.go", []string{element.OpsNoAsm, element.MulCIOS, element.MulNoCarry, element.Reduce, element.MulDoc}, WithBuildTag(pureGoBuildTag))) - g.Go(generate("vector_amd64.go", []string{element.VectorOpsAmd64}, 
Only(F.GenerateVectorOpsAMD64), WithBuildTag("!purego"))) + g.Go(generate("vector_amd64.go", []string{element.VectorOpsAmd64}, Only(F.GenerateVectorOpsAMD64 && !F.F31), WithBuildTag("!purego"))) + g.Go(generate("vector_amd64.go", []string{element.VectorOpsAmd64F31}, Only(F.GenerateVectorOpsAMD64 && F.F31), WithBuildTag("!purego"))) g.Go(generate("vector_arm64.go", []string{element.VectorOpsArm64}, Only(F.GenerateVectorOpsARM64), WithBuildTag("!purego"))) g.Go(generate("vector_purego.go", []string{element.VectorOpsPureGo}, WithBuildTag(pureGoVectorBuildTag))) - g.Go(generate("asm_adx.go", []string{element.Asm}, Only(F.GenerateOpsAMD64), WithBuildTag("!noadx"))) - g.Go(generate("asm_noadx.go", []string{element.AsmNoAdx}, Only(F.GenerateOpsAMD64), WithBuildTag("noadx"))) + g.Go(generate("asm_adx.go", []string{element.Asm}, Only(F.GenerateOpsAMD64 && !F.F31), WithBuildTag("!noadx"))) + g.Go(generate("asm_noadx.go", []string{element.AsmNoAdx}, Only(F.GenerateOpsAMD64 && !F.F31), WithBuildTag("noadx"))) g.Go(generate("asm_avx.go", []string{element.Avx}, Only(F.GenerateVectorOpsAMD64), WithBuildTag("!noavx"))) g.Go(generate("asm_noavx.go", []string{element.NoAvx}, Only(F.GenerateVectorOpsAMD64), WithBuildTag("noavx"))) diff --git a/field/generator/internal/templates/element/asm.go b/field/generator/internal/templates/element/asm.go index ed7fac7b40..53cadbc0a8 100644 --- a/field/generator/internal/templates/element/asm.go +++ b/field/generator/internal/templates/element/asm.go @@ -14,7 +14,7 @@ const Avx = ` import "golang.org/x/sys/cpu" var ( - supportAvx512 = supportAdx && cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + supportAvx512 = {{- if not .F31 }}supportAdx && {{- end}}cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ _ = supportAvx512 ) ` diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index e4de4f1b27..c3189a014b 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -145,3 +145,42 @@ func mulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64) ` const VectorOpsArm64 = VectorOpsPureGo + +const VectorOpsAmd64F31 = ` +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res {{.ElementName}}) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
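This template routes every vector operation to a *VecGeneric helper. Those helpers live in the package's vector.go and are not shown in this patch; their assumed shape, for context (a sketch that may differ from the actual source in detail):

// addVecGeneric is the scalar fallback behind Vector.Add: element-wise
// modular addition with the same length check the asm path performs.
func addVecGeneric(res, a, b Vector) {
	if len(a) != len(b) || len(a) != len(res) {
		panic("vector.Add: vectors don't have the same length")
	}
	for i := 0; i < len(a); i++ {
		res[i].Add(&a[i], &b[i])
	}
}
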
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} +` diff --git a/field/internal/main.go b/field/internal/main.go index b495d394e5..ce896f7fd3 100644 --- a/field/internal/main.go +++ b/field/internal/main.go @@ -23,12 +23,19 @@ func main() { {"babybear", "0x78000001"}, // 2^31 - 2^27 + 1 ==> 2-adicity 27 } + // generate assembly + asmDir := filepath.Join("..", "asm") + asmDirIncludePath := filepath.Join("..", "asm") + if err := generator.GenerateAMD64(1, asmDir, true); err != nil { + panic(err) + } + for _, f := range fields { fc, err := config.NewFieldConfig(f.name, "Element", f.modulus, true) if err != nil { panic(err) } - if err := generator.GenerateFF(fc, filepath.Join("..", f.name), "", ""); err != nil { + if err := generator.GenerateFF(fc, filepath.Join("..", f.name), asmDirIncludePath, asmDirIncludePath); err != nil { panic(err) } fmt.Println("successfully generated", f.name, "field") diff --git a/field/koalabear/asm_avx.go b/field/koalabear/asm_avx.go new file mode 100644 index 0000000000..cfc69e89d7 --- /dev/null +++ b/field/koalabear/asm_avx.go @@ -0,0 +1,26 @@ +//go:build !noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package koalabear + +import "golang.org/x/sys/cpu" + +var ( + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 +) diff --git a/field/koalabear/asm_noavx.go b/field/koalabear/asm_noavx.go new file mode 100644 index 0000000000..099a5be133 --- /dev/null +++ b/field/koalabear/asm_noavx.go @@ -0,0 +1,21 @@ +//go:build noavx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package koalabear + +const supportAvx512 = false diff --git a/field/koalabear/element_amd64.s b/field/koalabear/element_amd64.s new file mode 100644 index 0000000000..86531c7112 --- /dev/null +++ b/field/koalabear/element_amd64.s @@ -0,0 +1,21 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 16364262440698776471 +#include "../asm/element_1w_amd64.s" + diff --git a/field/koalabear/vector_amd64.go b/field/koalabear/vector_amd64.go new file mode 100644 index 0000000000..4f994da4ae --- /dev/null +++ b/field/koalabear/vector_amd64.go @@ -0,0 +1,56 @@ +//go:build !purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package koalabear + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/field/koalabear/vector_purego.go b/field/koalabear/vector_purego.go index 71dc2cc0c3..09b63a3a68 100644 --- a/field/koalabear/vector_purego.go +++ b/field/koalabear/vector_purego.go @@ -1,3 +1,5 @@ +//go:build purego || !amd64 + // Copyright 2020 ConsenSys Software Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); From 9216305d72b4dcc1d52f1008be93589a7edfa146 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sat, 7 Dec 2024 23:38:51 +0000 Subject: [PATCH 41/74] refactor: rename asm generation code for 4 words --- field/generator/asm/amd64/build.go | 17 ++++++++++------- field/generator/asm/amd64/element_vec.go | 10 +++++----- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 517b92bcce..c9e061cab6 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -303,18 +303,21 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { f.Comment("Vector operations are partially derived from Dag Arne Osvik's work in github.com/a16z/vectorized-fields") f.WriteLn("") - f.generateAddVec() - f.generateSubVec() - f.generateSumVec() - f.generateInnerProduct() - f.generateMulVec("scalarMulVec") - f.generateMulVec("mulVec") + f.generateAddVecW4() + f.generateSubVecW4() + f.generateSumVecW4() + f.generateInnerProductW4() + f.generateMulVecW4("scalarMulVec") + f.generateMulVecW4("mulVec") } return nil } func GenerateF31ASM(f *FFAmd64, hasVector bool) error { - f.Comment("TODO: implement F31 assembly code") + if !hasVector { + return nil // nothing for now. + } + return nil } diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index c67137d6cb..eab39f2eb7 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -23,7 +23,7 @@ import ( // addVec res = a + b // func addVec(res, a, b *{{.ElementName}}, n uint64) -func (f *FFAmd64) generateAddVec() { +func (f *FFAmd64) generateAddVecW4() { f.Comment("addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n]") const argSize = 4 * 8 @@ -85,7 +85,7 @@ func (f *FFAmd64) generateAddVec() { // subVec res = a - b // func subVec(res, a, b *{{.ElementName}}, n uint64) -func (f *FFAmd64) generateSubVec() { +func (f *FFAmd64) generateSubVecW4() { f.Comment("subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n]") const argSize = 4 * 8 @@ -160,7 +160,7 @@ func (f *FFAmd64) generateSubVec() { } // sumVec res = sum(a[0...n]) -func (f *FFAmd64) generateSumVec() { +func (f *FFAmd64) generateSumVecW4() { f.Comment("sumVec(res, a *Element, n uint64) res = sum(a[0...n])") const argSize = 3 * 8 @@ -452,7 +452,7 @@ func (f *FFAmd64) generateSumVec() { f.Push(®isters, w0l, w1l, w2l, w3l, w3h) } -func (f *FFAmd64) generateInnerProduct() { +func (f *FFAmd64) generateInnerProductW4() { f.Comment("innerProdVec(res, a,b *Element, n uint64) res = sum(a[0...n] * b[0...n])") const argSize = 4 * 8 @@ -822,7 +822,7 @@ func (f *FFAmd64) generateInnerProduct() { f.RET() } -func (f *FFAmd64) generateMulVec(funcName string) { +func (f *FFAmd64) generateMulVecW4(funcName string) { scalarMul := funcName != "mulVec" const argSize = 5 * 8 From deb1d7fcc19502c8692ec38816769dd5720e2a76 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sat, 7 Dec 2024 23:39:08 +0000 Subject: [PATCH 42/74] refactor: rename asm generation code for 4 words --- .../generator/asm/amd64/{element_vec.go => element_vec_4words.go} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename field/generator/asm/amd64/{element_vec.go => element_vec_4words.go} (100%) diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec_4words.go similarity index 100% rename from 
field/generator/asm/amd64/element_vec.go rename to field/generator/asm/amd64/element_vec_4words.go From 8f230c27b6f7b27833213638d0676f09c78e60e6 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 00:15:18 +0000 Subject: [PATCH 43/74] feat: added F31 avx512 add --- field/asm/element_1w_amd64.s | 30 +++- field/babybear/element_amd64.s | 2 +- field/babybear/vector_amd64.go | 24 ++- field/generator/asm/amd64/build.go | 4 + field/generator/asm/amd64/element_vec_F31.go | 162 ++++++++++++++++++ .../templates/element/vector_ops_asm.go | 24 ++- field/koalabear/element_amd64.s | 2 +- field/koalabear/vector_amd64.go | 24 ++- go.mod | 2 +- go.sum | 6 + 10 files changed, 273 insertions(+), 7 deletions(-) create mode 100644 field/generator/asm/amd64/element_vec_F31.go diff --git a/field/asm/element_1w_amd64.s b/field/asm/element_1w_amd64.s index 10ae161550..ce733c0013 100644 --- a/field/asm/element_1w_amd64.s +++ b/field/asm/element_1w_amd64.s @@ -3,4 +3,32 @@ #include "funcdata.h" #include "go_asm.h" -// TODO: implement F31 assembly code +// Vector operations are partially derived from Plonky3 https://github.com/Plonky3/Plonky3 +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVD $const_q, AX + VPBROADCASTD AX, Z3 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + VMOVDQU32 0(AX), Z0 + VMOVDQU32 0(DX), Z1 + VPADDD Z0, Z1, Z0 + VPSUBD Z3, Z0, Z2 + VPMINUD Z0, Z2, Z1 + VMOVDQU32 Z1, 0(CX) + + // increment pointers to visit next element + ADDQ $64, AX + ADDQ $64, DX + ADDQ $64, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET diff --git a/field/babybear/element_amd64.s b/field/babybear/element_amd64.s index 86531c7112..10463fe14d 100644 --- a/field/babybear/element_amd64.s +++ b/field/babybear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 16364262440698776471 +// We include the hash to force the Go compiler to recompile: 929674669831530430 #include "../asm/element_1w_amd64.s" diff --git a/field/babybear/vector_amd64.go b/field/babybear/vector_amd64.go index cd1b7e5fc6..ce153a9dd4 100644 --- a/field/babybear/vector_amd64.go +++ b/field/babybear/vector_amd64.go @@ -21,9 +21,31 @@ package babybear // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call addVecGeneric + addVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + addVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + if n%blockSize != 0 { + // call addVecGeneric on the rest + start := n - n%blockSize + addVecGeneric((*vector)[start:], a[start:], b[start:]) + } } +//go:noescape +func addVec(res, a, b *Element, n uint64) + // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
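The VPADDD/VPSUBD/VPMINUD sequence in the generated addVec above is the usual branch-free modular add for q < 2³¹ (the Plonky3 trick the file credits). A scalar Go model of what each 32-bit lane computes (illustrative; not repository code):

// one lane of addVec: t = a+b never overflows since a, b < q < 2^31;
// u = t-q wraps to a large value exactly when t < q, so the unsigned
// minimum of t and u is (a+b) mod q. VPMINUD computes that minimum.
func addModLane(a, b, q uint32) uint32 {
	t := a + b
	u := t - q
	if u < t {
		return u
	}
	return t
}

One unsigned min replaces a compare-and-mask pair, which is why the emitted loop needs only three vector instructions per 16 lanes.
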
func (vector *Vector) Sub(a, b Vector) { diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index c9e061cab6..719d7f92c8 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -319,5 +319,9 @@ func GenerateF31ASM(f *FFAmd64, hasVector bool) error { return nil // nothing for now. } + f.Comment("Vector operations are partially derived from Plonky3 https://github.com/Plonky3/Plonky3") + + f.generateAddVecF31() + return nil } diff --git a/field/generator/asm/amd64/element_vec_F31.go b/field/generator/asm/amd64/element_vec_F31.go new file mode 100644 index 0000000000..d14f8c406a --- /dev/null +++ b/field/generator/asm/amd64/element_vec_F31.go @@ -0,0 +1,162 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package amd64 + +import "github.com/consensys/bavard/amd64" + +// addVec res = a + b +// func addVec(res, a, b *{{.ElementName}}, n uint64) +func (f *FFAmd64) generateAddVecF31() { + f.Comment("addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n]") + + const argSize = 4 * 8 + stackSize := f.StackSize(f.NbWords*2+4, 0, 0) + registers := f.FnHeader("addVec", stackSize, argSize) + defer f.AssertCleanStack(stackSize, 0) + + // registers & labels we need + addrA := f.Pop(®isters) + addrB := f.Pop(®isters) + addrRes := f.Pop(®isters) + len := f.Pop(®isters) + + // AVX512 registers + a := amd64.Register("Z0") + b := amd64.Register("Z1") + t := amd64.Register("Z2") + q := amd64.Register("Z3") + + // load q in Z3 + f.WriteLn("MOVD $const_q, AX") + f.VPBROADCASTD("AX", q) + + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.MOVQ("res+0(FP)", addrRes) + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("b+16(FP)", addrB) + f.MOVQ("n+24(FP)", len) + + f.LABEL(loop) + + f.TESTQ(len, len) + f.JEQ(done, "n == 0, we are done") + + // a = a + b + f.VMOVDQU32(addrA.At(0), a) + f.VMOVDQU32(addrB.At(0), b) + f.VPADDD(a, b, a) + + // t = a - q + f.VPSUBD(q, a, t) + + // b = min(t, a) + f.VPMINUD(a, t, b) + + // move b to res + f.VMOVDQU32(b, addrRes.At(0)) + + f.Comment("increment pointers to visit next element") + f.ADDQ("$64", addrA) + f.ADDQ("$64", addrB) + f.ADDQ("$64", addrRes) + f.DECQ(len, "decrement n") + f.JMP(loop) + + f.LABEL(done) + + f.RET() + + f.Push(®isters, addrA, addrB, addrRes, len) + +} + +// // subVec res = a - b +// // func subVec(res, a, b *{{.ElementName}}, n uint64) +// func (f *FFAmd64) generateSubVecW4() { +// f.Comment("subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n]") + +// const argSize = 4 * 8 +// stackSize := f.StackSize(f.NbWords*2+5, 0, 0) +// registers := f.FnHeader("subVec", stackSize, argSize) +// defer f.AssertCleanStack(stackSize, 0) + +// // registers +// addrA := f.Pop(®isters) +// addrB := f.Pop(®isters) +// addrRes := f.Pop(®isters) +// len := f.Pop(®isters) +// zero := f.Pop(®isters) + +// a := f.PopN(®isters) +// q := f.PopN(®isters) + +// loop := f.NewLabel("loop") +// done := 
f.NewLabel("done") + +// // load arguments +// f.MOVQ("res+0(FP)", addrRes) +// f.MOVQ("a+8(FP)", addrA) +// f.MOVQ("b+16(FP)", addrB) +// f.MOVQ("n+24(FP)", len) + +// f.XORQ(zero, zero) + +// f.LABEL(loop) + +// f.TESTQ(len, len) +// f.JEQ(done, "n == 0, we are done") + +// // a = a - b +// f.LabelRegisters("a", a...) +// f.Mov(addrA, a) +// f.Sub(addrB, a) +// f.WriteLn(fmt.Sprintf("PREFETCHT0 2048(%[1]s)", addrA)) +// f.WriteLn(fmt.Sprintf("PREFETCHT0 2048(%[1]s)", addrB)) + +// // reduce a +// f.Comment("reduce (a-b) mod q") +// f.LabelRegisters("q", q...) +// for i := 0; i < f.NbWords; i++ { +// f.MOVQ(fmt.Sprintf("$const_q%d", i), q[i]) +// } +// for i := 0; i < f.NbWords; i++ { +// f.CMOVQCC(zero, q[i]) +// } +// // add registers (q or 0) to a, and set to result +// f.Comment("add registers (q or 0) to a, and set to result") +// f.Add(q, a) + +// // save a into res +// f.Mov(a, addrRes) + +// f.Comment("increment pointers to visit next element") +// f.ADDQ("$32", addrA) +// f.ADDQ("$32", addrB) +// f.ADDQ("$32", addrRes) +// f.DECQ(len, "decrement n") +// f.JMP(loop) + +// f.LABEL(done) + +// f.RET() + +// f.Push(®isters, a...) +// f.Push(®isters, q...) +// f.Push(®isters, addrA, addrB, addrRes, len, zero) + +// } diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index c3189a014b..d4b86971f4 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -150,9 +150,31 @@ const VectorOpsAmd64F31 = ` // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call addVecGeneric + addVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + addVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + if n % blockSize != 0 { + // call addVecGeneric on the rest + start := n - n % blockSize + addVecGeneric((*vector)[start:], a[start:], b[start:]) + } } +//go:noescape +func addVec(res, a, b *{{.ElementName}}, n uint64) + // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Sub(a, b Vector) { diff --git a/field/koalabear/element_amd64.s b/field/koalabear/element_amd64.s index 86531c7112..10463fe14d 100644 --- a/field/koalabear/element_amd64.s +++ b/field/koalabear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 16364262440698776471 +// We include the hash to force the Go compiler to recompile: 929674669831530430 #include "../asm/element_1w_amd64.s" diff --git a/field/koalabear/vector_amd64.go b/field/koalabear/vector_amd64.go index 4f994da4ae..de48e05229 100644 --- a/field/koalabear/vector_amd64.go +++ b/field/koalabear/vector_amd64.go @@ -21,9 +21,31 @@ package koalabear // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call addVecGeneric + addVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + addVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + if n%blockSize != 0 { + // call addVecGeneric on the rest + start := n - n%blockSize + addVecGeneric((*vector)[start:], a[start:], b[start:]) + } } +//go:noescape +func addVec(res, a, b *Element, n uint64) + // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Sub(a, b Vector) { diff --git a/go.mod b/go.mod index c4dbc8bc84..45e3c85e9e 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3 + github.com/consensys/bavard v0.1.23-0.20241208000453-1b3c9246dcd6 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index af73f869f3..5adf0a8283 100644 --- a/go.sum +++ b/go.sum @@ -65,6 +65,12 @@ github.com/consensys/bavard v0.1.23-0.20241022191117-d73e50a886cc h1:NwWCvGXSPH8 github.com/consensys/bavard v0.1.23-0.20241022191117-d73e50a886cc/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3 h1:8gPxbjhwhxXTakOXII32eLlAFLlYImoENa3uQ6iP+go= github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241207235124-babad3045f79 h1:lhIivWq5SgulQUNtgUugSMqcIpQNZkB5EPD/CwF3r9w= +github.com/consensys/bavard v0.1.23-0.20241207235124-babad3045f79/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241207235803-84aa6b3d4724 h1:wBPDHYgf1QvlnW/7gVZVBYVgkKjYV1J8Hbsa5qwvESs= +github.com/consensys/bavard v0.1.23-0.20241207235803-84aa6b3d4724/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241208000453-1b3c9246dcd6 h1:dm/VT++/p4tq8FLR/8z361AvWPD9dcp6xXebPLaEdZo= +github.com/consensys/bavard v0.1.23-0.20241208000453-1b3c9246dcd6/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= From d5a6b4d11ddeafb9a1145d8f8a26aa8819dd5a50 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 00:27:57 +0000 Subject: [PATCH 44/74] feat: add avx512 sub for f31 --- field/asm/element_1w_amd64.s | 29 ++++++++ field/babybear/element_amd64.s | 2 +- field/babybear/vector_amd64.go | 24 +++++- field/generator/asm/amd64/build.go | 1 + field/generator/asm/amd64/element_vec_F31.go | 74 ++++++++++++++++++- .../templates/element/vector_ops_asm.go | 24 +++++- field/koalabear/element_amd64.s | 2 +- field/koalabear/vector_amd64.go | 24 +++++- 8 files changed, 173 insertions(+), 7 deletions(-) diff --git a/field/asm/element_1w_amd64.s b/field/asm/element_1w_amd64.s index ce733c0013..d0db310d6a 100644 --- a/field/asm/element_1w_amd64.s +++ 
b/field/asm/element_1w_amd64.s @@ -32,3 +32,32 @@ loop_1: done_2: RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVD $const_q, AX + VPBROADCASTD AX, Z3 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + VMOVDQU32 0(AX), Z0 + VMOVDQU32 0(DX), Z1 + VPSUBD Z1, Z0, Z0 + VPADDD Z3, Z0, Z2 + VPMINUD Z0, Z2, Z1 + VMOVDQU32 Z1, 0(CX) + + // increment pointers to visit next element + ADDQ $64, AX + ADDQ $64, DX + ADDQ $64, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET diff --git a/field/babybear/element_amd64.s b/field/babybear/element_amd64.s index 10463fe14d..8a1c0ab405 100644 --- a/field/babybear/element_amd64.s +++ b/field/babybear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 929674669831530430 +// We include the hash to force the Go compiler to recompile: 4819854176290243046 #include "../asm/element_1w_amd64.s" diff --git a/field/babybear/vector_amd64.go b/field/babybear/vector_amd64.go index ce153a9dd4..ccfaf2e303 100644 --- a/field/babybear/vector_amd64.go +++ b/field/babybear/vector_amd64.go @@ -49,9 +49,31 @@ func addVec(res, a, b *Element, n uint64) // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call subVecGeneric + subVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + subVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + if n%blockSize != 0 { + // call subVecGeneric on the rest + start := n - n%blockSize + subVecGeneric((*vector)[start:], a[start:], b[start:]) + } } +//go:noescape +func subVec(res, a, b *Element, n uint64) + // ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
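The subVec kernel above uses the mirrored trick: the 32-bit difference wraps high exactly when a < b, and adding q back wraps it into the correct residue, so the unsigned min again selects the reduced value. One lane in scalar form (illustrative only, not generated code):

    // subMod models one 32-bit lane of subVec: q is the field modulus, a, b < q.
    func subMod(a, b, q uint32) uint32 {
        d := a - b // wraps high when a < b
        t := d + q // wraps back to a-b+q in that case
        if t < d { // unsigned min(d, t), as VPMINUD does per lane
            return t
        }
        return d
    }
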
func (vector *Vector) ScalarMul(a Vector, b *Element) { diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 719d7f92c8..8e71b424e2 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -322,6 +322,7 @@ func GenerateF31ASM(f *FFAmd64, hasVector bool) error { f.Comment("Vector operations are partially derived from Plonky3 https://github.com/Plonky3/Plonky3") f.generateAddVecF31() + f.generateSubVecF31() return nil } diff --git a/field/generator/asm/amd64/element_vec_F31.go b/field/generator/asm/amd64/element_vec_F31.go index d14f8c406a..c513a3d818 100644 --- a/field/generator/asm/amd64/element_vec_F31.go +++ b/field/generator/asm/amd64/element_vec_F31.go @@ -14,7 +14,9 @@ package amd64 -import "github.com/consensys/bavard/amd64" +import ( + "github.com/consensys/bavard/amd64" +) // addVec res = a + b // func addVec(res, a, b *{{.ElementName}}, n uint64) @@ -60,9 +62,77 @@ func (f *FFAmd64) generateAddVecF31() { f.VMOVDQU32(addrA.At(0), a) f.VMOVDQU32(addrB.At(0), b) f.VPADDD(a, b, a) - // t = a - q f.VPSUBD(q, a, t) + // b = min(t, a) + f.VPMINUD(a, t, b) + + // move b to res + f.VMOVDQU32(b, addrRes.At(0)) + + f.Comment("increment pointers to visit next element") + f.ADDQ("$64", addrA) + f.ADDQ("$64", addrB) + f.ADDQ("$64", addrRes) + f.DECQ(len, "decrement n") + f.JMP(loop) + + f.LABEL(done) + + f.RET() + + f.Push(®isters, addrA, addrB, addrRes, len) + +} + +// subVec res = a - b +// func subVec(res, a, b *{{.ElementName}}, n uint64) +func (f *FFAmd64) generateSubVecF31() { + f.Comment("subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n]") + + const argSize = 4 * 8 + stackSize := f.StackSize(f.NbWords*2+4, 0, 0) + registers := f.FnHeader("subVec", stackSize, argSize) + defer f.AssertCleanStack(stackSize, 0) + + // registers & labels we need + addrA := f.Pop(®isters) + addrB := f.Pop(®isters) + addrRes := f.Pop(®isters) + len := f.Pop(®isters) + + // AVX512 registers + a := amd64.Register("Z0") + b := amd64.Register("Z1") + t := amd64.Register("Z2") + q := amd64.Register("Z3") + + // load q in Z3 + f.WriteLn("MOVD $const_q, AX") + f.VPBROADCASTD("AX", q) + + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.MOVQ("res+0(FP)", addrRes) + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("b+16(FP)", addrB) + f.MOVQ("n+24(FP)", len) + + f.LABEL(loop) + + f.TESTQ(len, len) + f.JEQ(done, "n == 0, we are done") + + // a = a - b + f.VMOVDQU32(addrA.At(0), a) + f.VMOVDQU32(addrB.At(0), b) + + f.VPSUBD(b, a, a) + + // t = a + q + f.VPADDD(q, a, t) // b = min(t, a) f.VPMINUD(a, t, b) diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index d4b86971f4..5f04566dea 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -178,9 +178,31 @@ func addVec(res, a, b *{{.ElementName}}, n uint64) // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
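Since the assembly and generic paths must agree on every length, a natural check exercises the full 16-lane blocks and the scalar tail in one call. A test-style sketch, where randomVector is an assumed helper and not part of this patch:

    import "testing"

    func TestSubMatchesGeneric(t *testing.T) {
        const n = 100 // 6 full blocks + a 4-element tail
        a, b := randomVector(n), randomVector(n) // assumed helper
        got, want := make(Vector, n), make(Vector, n)
        got.Sub(a, b)               // AVX512 path when supported
        subVecGeneric(want, a, b)   // reference path
        for i := range got {
            if !got[i].Equal(&want[i]) {
                t.Fatalf("mismatch at index %d", i)
            }
        }
    }
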
func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call subVecGeneric + subVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + subVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + if n % blockSize != 0 { + // call subVecGeneric on the rest + start := n - n % blockSize + subVecGeneric((*vector)[start:], a[start:], b[start:]) + } } +//go:noescape +func subVec(res, a, b *{{.ElementName}}, n uint64) + // ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { diff --git a/field/koalabear/element_amd64.s b/field/koalabear/element_amd64.s index 10463fe14d..8a1c0ab405 100644 --- a/field/koalabear/element_amd64.s +++ b/field/koalabear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 929674669831530430 +// We include the hash to force the Go compiler to recompile: 4819854176290243046 #include "../asm/element_1w_amd64.s" diff --git a/field/koalabear/vector_amd64.go b/field/koalabear/vector_amd64.go index de48e05229..1b87342f6a 100644 --- a/field/koalabear/vector_amd64.go +++ b/field/koalabear/vector_amd64.go @@ -49,9 +49,31 @@ func addVec(res, a, b *Element, n uint64) // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call subVecGeneric + subVecGeneric(*vector, a, b) + return + } + + const blockSize = 16 + subVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + if n%blockSize != 0 { + // call subVecGeneric on the rest + start := n - n%blockSize + subVecGeneric((*vector)[start:], a[start:], b[start:]) + } } +//go:noescape +func subVec(res, a, b *Element, n uint64) + // ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) ScalarMul(a Vector, b *Element) { From cf8370ddd7e0b4f2873f7936451e7546ec7c56b0 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 03:21:47 +0000 Subject: [PATCH 45/74] feat: added avx512 sum for f31 --- field/asm/element_1w_amd64.s | 32 +++++++++ field/babybear/element_amd64.s | 2 +- field/babybear/vector_amd64.go | 30 ++++++++- field/generator/asm/amd64/build.go | 1 + field/generator/asm/amd64/element_vec_F31.go | 66 +++++++++++++++++++ .../templates/element/vector_ops_asm.go | 30 ++++++++- field/koalabear/element_amd64.s | 2 +- field/koalabear/vector_amd64.go | 30 ++++++++- 8 files changed, 188 insertions(+), 5 deletions(-) diff --git a/field/asm/element_1w_amd64.s b/field/asm/element_1w_amd64.s index d0db310d6a..07ac2a512a 100644 --- a/field/asm/element_1w_amd64.s +++ b/field/asm/element_1w_amd64.s @@ -61,3 +61,35 @@ loop_3: done_4: RET + +// sumVec(res *uint64, a *[]uint32, n uint64) res = sum(a[0...n]) + +// We are load 8 31bits values at a time and accumulate them into an accumulator of +// 8 quadwords (64bits). 
The caller then needs to reduce the result mod q. +// We can safely accumulate ~2**33 31bits values into a single accumulator. +// That gives us a maximum of 2**33 * 8 = 2**36 31bits values to sum safely. + +TEXT ·sumVec(SB), NOSPLIT, $0-24 + MOVQ t+0(FP), R15 + MOVQ a+8(FP), R14 + MOVQ n+16(FP), CX + VXORPS Z2, Z2, Z2 + VMOVDQA64 Z2, Z3 + +loop_5: + TESTQ CX, CX + JEQ done_6 // n == 0, we are done + VPMOVZXDQ 0(R14), Z0 + VPMOVZXDQ 32(R14), Z1 + VPADDQ Z0, Z2, Z2 + VPADDQ Z1, Z3, Z3 + + // increment pointers to visit next element + ADDQ $64, R14 + DECQ CX // decrement n + JMP loop_5 + +done_6: + VPADDQ Z2, Z3, Z2 + VMOVDQU64 Z2, 0(R15) + RET diff --git a/field/babybear/element_amd64.s b/field/babybear/element_amd64.s index 8a1c0ab405..c52db5c441 100644 --- a/field/babybear/element_amd64.s +++ b/field/babybear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 4819854176290243046 +// We include the hash to force the Go compiler to recompile: 15742393007638587690 #include "../asm/element_1w_amd64.s" diff --git a/field/babybear/vector_amd64.go b/field/babybear/vector_amd64.go index ccfaf2e303..7004445268 100644 --- a/field/babybear/vector_amd64.go +++ b/field/babybear/vector_amd64.go @@ -82,10 +82,38 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) + n := uint64(len(*vector)) + if n == 0 { + return + } + if !supportAvx512 { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + + const blockSize = 16 + var t [8]uint64 // stores the accumulators (not reduced mod q) + sumVec(&t[0], &(*vector)[0], n/blockSize) + // we reduce the accumulators mod q and add to res + var v Element + for i := 0; i < 8; i++ { + t[i] %= q + v[0] = uint32(t[i]) + res.Add(&res, &v) + } + if n%blockSize != 0 { + // call sumVecGeneric on the rest + start := n - n%blockSize + sumVecGeneric(&res, (*vector)[start:]) + } + return } +//go:noescape +func sumVec(t *uint64, a *Element, n uint64) + // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. func (vector *Vector) InnerProduct(other Vector) (res Element) { diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 8e71b424e2..936e88836b 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -323,6 +323,7 @@ func GenerateF31ASM(f *FFAmd64, hasVector bool) error { f.generateAddVecF31() f.generateSubVecF31() + f.generateSumVecF31() return nil } diff --git a/field/generator/asm/amd64/element_vec_F31.go b/field/generator/asm/amd64/element_vec_F31.go index c513a3d818..5212b93950 100644 --- a/field/generator/asm/amd64/element_vec_F31.go +++ b/field/generator/asm/amd64/element_vec_F31.go @@ -155,6 +155,72 @@ func (f *FFAmd64) generateSubVecF31() { } +// sumVec res = sum(a[0...n]) +func (f *FFAmd64) generateSumVecF31() { + f.Comment("sumVec(res *uint64, a *[]uint32, n uint64) res = sum(a[0...n])") + f.WriteLn(` + // We are load 8 31bits values at a time and accumulate them into an accumulator of + // 8 quadwords (64bits). The caller then needs to reduce the result mod q. + // We can safely accumulate ~2**33 31bits values into a single accumulator. + // That gives us a maximum of 2**33 * 8 = 2**36 31bits values to sum safely. 
+ `) + + const argSize = 3 * 8 + stackSize := f.StackSize(f.NbWords*3+2, 0, 0) + registers := f.FnHeader("sumVec", stackSize, argSize, amd64.DX, amd64.AX) + defer f.AssertCleanStack(stackSize, 0) + + // registers & labels we need + addrA := f.Pop(®isters) + addrT := f.Pop(®isters) + len := f.Pop(®isters) + + // AVX512 registers + a1 := amd64.Register("Z0") + a2 := amd64.Register("Z1") + acc1 := amd64.Register("Z2") + acc2 := amd64.Register("Z3") + + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.MOVQ("t+0(FP)", addrT) + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("n+16(FP)", len) + + // zeroize the accumulators + f.VXORPS(acc1, acc1, acc1) + f.VMOVDQA64(acc1, acc2) + + f.LABEL(loop) + + f.TESTQ(len, len) + f.JEQ(done, "n == 0, we are done") + + // 1 cache line is typically 64 bytes, so we maintain 2 accumulators + f.VPMOVZXDQ(addrA.At(0), a1) + f.VPMOVZXDQ(addrA.At(4), a2) + + f.VPADDQ(a1, acc1, acc1) + f.VPADDQ(a2, acc2, acc2) + + f.Comment("increment pointers to visit next element") + f.ADDQ("$64", addrA) + f.DECQ(len, "decrement n") + f.JMP(loop) + + f.LABEL(done) + + // store t into res + f.VPADDQ(acc1, acc2, acc1) + f.VMOVDQU64(acc1, addrT.At(0)) + + f.RET() + + f.Push(®isters, addrA, addrT, len) +} + // // subVec res = a - b // // func subVec(res, a, b *{{.ElementName}}, n uint64) // func (f *FFAmd64) generateSubVecW4() { diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index 5f04566dea..5c1b5115f1 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -211,10 +211,38 @@ func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res {{.ElementName}}) { - sumVecGeneric(&res, *vector) + n := uint64(len(*vector)) + if n == 0 { + return + } + if !supportAvx512 { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + + const blockSize = 16 + var t [8]uint64 // stores the accumulators (not reduced mod q) + sumVec(&t[0], &(*vector)[0], n/blockSize) + // we reduce the accumulators mod q and add to res + var v {{.ElementName}} + for i := 0; i < 8; i++ { + t[i] %= q + v[0] = uint32(t[i]) + res.Add(&res, &v) + } + if n % blockSize != 0 { + // call sumVecGeneric on the rest + start := n - n % blockSize + sumVecGeneric(&res, (*vector)[start:]) + } + return } +//go:noescape +func sumVec(t *uint64, a *{{.ElementName}}, n uint64) + // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { diff --git a/field/koalabear/element_amd64.s b/field/koalabear/element_amd64.s index 8a1c0ab405..c52db5c441 100644 --- a/field/koalabear/element_amd64.s +++ b/field/koalabear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 4819854176290243046 +// We include the hash to force the Go compiler to recompile: 15742393007638587690 #include "../asm/element_1w_amd64.s" diff --git a/field/koalabear/vector_amd64.go b/field/koalabear/vector_amd64.go index 1b87342f6a..5a53382ed7 100644 --- a/field/koalabear/vector_amd64.go +++ b/field/koalabear/vector_amd64.go @@ -82,10 +82,38 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { // Sum computes the sum of all elements in the vector. 
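The safety margin quoted in the sumVec comment falls out of a one-line bound: each 64-bit lane accumulates values below 2^31, so roughly 2^33 additions fit before wraparound, and the 8 lanes of one accumulator extend that to about 2^36 elements. A back-of-envelope check in Go:

    import "math"

    // maxSafeSumLen bounds how many <2^31 inputs sumVec can accumulate
    // before a 64-bit lane could wrap.
    func maxSafeSumLen() uint64 {
        const maxInput = uint64(1)<<31 - 1   // inputs are < 2^31
        perLane := math.MaxUint64 / maxInput // ≈ 2^33 additions per lane
        return perLane * 8                   // 8 lanes per 512-bit accumulator
    }
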
func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) + n := uint64(len(*vector)) + if n == 0 { + return + } + if !supportAvx512 { + // call sumVecGeneric + sumVecGeneric(&res, *vector) + return + } + + const blockSize = 16 + var t [8]uint64 // stores the accumulators (not reduced mod q) + sumVec(&t[0], &(*vector)[0], n/blockSize) + // we reduce the accumulators mod q and add to res + var v Element + for i := 0; i < 8; i++ { + t[i] %= q + v[0] = uint32(t[i]) + res.Add(&res, &v) + } + if n%blockSize != 0 { + // call sumVecGeneric on the rest + start := n - n%blockSize + sumVecGeneric(&res, (*vector)[start:]) + } + return } +//go:noescape +func sumVec(t *uint64, a *Element, n uint64) + // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. func (vector *Vector) InnerProduct(other Vector) (res Element) { From 8289d3d25726bdfe7950476caa113e86582f9a19 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 05:02:05 +0000 Subject: [PATCH 46/74] feat: working version of the mul, optims to come --- field/asm/element_1w_amd64.s | 41 ++++++++ field/babybear/element_amd64.s | 2 +- field/babybear/vector_amd64.go | 24 ++++- field/generator/asm/amd64/build.go | 1 + field/generator/asm/amd64/element_vec_F31.go | 99 +++++++++++++++++++ field/generator/config/field_config.go | 20 ++++ .../templates/element/vector_ops_asm.go | 25 ++++- field/koalabear/element_amd64.s | 2 +- field/koalabear/vector_amd64.go | 24 ++++- go.mod | 2 +- go.sum | 4 + 11 files changed, 238 insertions(+), 6 deletions(-) diff --git a/field/asm/element_1w_amd64.s b/field/asm/element_1w_amd64.s index 07ac2a512a..091d00aa18 100644 --- a/field/asm/element_1w_amd64.s +++ b/field/asm/element_1w_amd64.s @@ -93,3 +93,44 @@ done_6: VPADDQ Z2, Z3, Z2 VMOVDQU64 Z2, 0(R15) RET + +// mulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b[0...n] +TEXT ·mulVec(SB), NOSPLIT, $0-32 + MOVD $const_q, AX + VPBROADCASTQ AX, Z3 + MOVD $const_qInvNeg, AX + VPBROADCASTQ AX, Z4 + + // Create mask for low dword in each qword + VPCMPEQB Y0, Y0, Y0 + VPMOVZXDQ Y0, Z6 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_7: + TESTQ BX, BX + JEQ done_8 // n == 0, we are done + VPMOVZXDQ 0(AX), Z0 + VPMOVZXDQ 0(DX), Z1 + VPMULUDQ Z0, Z1, Z2 + VPANDQ Z6, Z2, Z5 + VPMULUDQ Z5, Z4, Z5 + VPANDQ Z6, Z5, Z5 + VPMULUDQ Z5, Z3, Z5 + VPADDQ Z2, Z5, Z2 + VPSRLQ $32, Z2, Z2 + VPSUBD Z3, Z2, Z5 + VPMINUD Z2, Z5, Z2 + VPMOVQD Z2, 0(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_7 + +done_8: + RET diff --git a/field/babybear/element_amd64.s b/field/babybear/element_amd64.s index c52db5c441..da45886f9c 100644 --- a/field/babybear/element_amd64.s +++ b/field/babybear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 15742393007638587690 +// We include the hash to force the Go compiler to recompile: 15784575404660199304 #include "../asm/element_1w_amd64.s" diff --git a/field/babybear/vector_amd64.go b/field/babybear/vector_amd64.go index 7004445268..650e02eee7 100644 --- a/field/babybear/vector_amd64.go +++ b/field/babybear/vector_amd64.go @@ -124,5 +124,27 @@ func (vector *Vector) InnerProduct(other Vector) (res Element) { // Mul multiplies two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
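The mulVec kernel above keeps operands in the Montgomery domain and performs the classic reduction inside each widened 64-bit lane. A scalar model of one lane follows; it is a sketch that assumes qInvNeg = -q^-1 mod 2^32, which is what the constant's name and the VPADDQ-based cancellation suggest:

    // montMul models one mulVec lane: a, b are Montgomery forms < q,
    // q < 2^31, qInvNeg = -q^-1 mod 2^32.
    func montMul(a, b, q, qInvNeg uint32) uint32 {
        p := uint64(a) * uint64(b)          // P = a * b, < q^2
        m := uint32(p) * qInvNeg            // m = uint32(P) * qInvNeg (mod 2^32)
        p = (p + uint64(m)*uint64(q)) >> 32 // low 32 bits cancel, so the shift is exact
        r := uint32(p)                      // r < 2q
        if t := r - q; t < r {              // unsigned min(r, r-q)
            return t
        }
        return r
    }
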
func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 8 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } } + +//go:noescape +func mulVec(res, a, b *Element, n uint64) diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 936e88836b..7118ff9178 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -324,6 +324,7 @@ func GenerateF31ASM(f *FFAmd64, hasVector bool) error { f.generateAddVecF31() f.generateSubVecF31() f.generateSumVecF31() + f.generateMulVecF31() return nil } diff --git a/field/generator/asm/amd64/element_vec_F31.go b/field/generator/asm/amd64/element_vec_F31.go index 5212b93950..c9d006af74 100644 --- a/field/generator/asm/amd64/element_vec_F31.go +++ b/field/generator/asm/amd64/element_vec_F31.go @@ -15,6 +15,8 @@ package amd64 import ( + "fmt" + "github.com/consensys/bavard/amd64" ) @@ -221,6 +223,103 @@ func (f *FFAmd64) generateSumVecF31() { f.Push(®isters, addrA, addrT, len) } +// mulVec res = a * b +func (f *FFAmd64) generateMulVecF31() { + f.Comment("mulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b[0...n]") + + const argSize = 4 * 8 + stackSize := f.StackSize(f.NbWords*2+4, 0, 0) + registers := f.FnHeader("mulVec", stackSize, argSize) + defer f.AssertCleanStack(stackSize, 0) + + // registers & labels we need + addrA := f.Pop(®isters) + addrB := f.Pop(®isters) + addrRes := f.Pop(®isters) + len := f.Pop(®isters) + + // AVX512 registers + a := amd64.Register("Z0") + b := amd64.Register("Z1") + P := amd64.Register("Z2") + q := amd64.Register("Z3") + qInvNeg := amd64.Register("Z4") + PL := amd64.Register("Z5") + LSW := amd64.Register("Z6") + + // load q in Z3 + f.WriteLn("MOVD $const_q, AX") + f.VPBROADCASTQ("AX", q) + f.WriteLn("MOVD $const_qInvNeg, AX") + f.VPBROADCASTQ("AX", qInvNeg) + + f.Comment("Create mask for low dword in each qword") + f.VPCMPEQB("Y0", "Y0", "Y0") + f.VPMOVZXDQ("Y0", LSW) + + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.MOVQ("res+0(FP)", addrRes) + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("b+16(FP)", addrB) + f.MOVQ("n+24(FP)", len) + + f.LABEL(loop) + + f.TESTQ(len, len) + f.JEQ(done, "n == 0, we are done") + + // a = a * b + f.VPMOVZXDQ(addrA.At(0), a) + f.VPMOVZXDQ(addrB.At(0), b) + f.VPMULUDQ(a, b, P) + // f.VPSRLQ("$32", P, PH) + f.VPANDQ(LSW, P, PL) // low dword + // m := uint32(v) * qInvNeg --> m = PL * qInvNeg + f.VPMULUDQ(PL, qInvNeg, PL) + f.VPANDQ(LSW, PL, PL) // mod R --> keep low dword + // m*=q + f.VPMULUDQ(PL, q, PL) + // add P + f.VPADDQ(P, PL, P) + f.VPSRLQ("$32", P, P) // shift right by 32 bits + + // now we need to use min to reduce + // first sub q from P + f.VPSUBD(q, P, PL) + + // res = min(P, PL) + f.VPMINUD(P, PL, P) + + // move P to res + f.WriteLn(fmt.Sprintf("VPMOVQD %s, %s", P, addrRes.At(0))) + // f.VMOVDQU32(P, addrRes.At(0)) + + // now we need to montReduce + + // // a = a - b + // f.VMOVDQU32(addrA.At(0), a) + // f.VMOVDQU32(addrB.At(0), b) + + // f.VPSUBD(b, a, a) + + f.Comment("increment pointers to visit next element") + f.ADDQ("$32", addrA) + f.ADDQ("$32", addrB) + 
f.ADDQ("$32", addrRes) + f.DECQ(len, "decrement n") + f.JMP(loop) + + f.LABEL(done) + + f.RET() + + f.Push(®isters, addrA, addrB, addrRes, len) + +} + // // subVec res = a - b // // func subVec(res, a, b *{{.ElementName}}, n uint64) // func (f *FFAmd64) generateSubVecW4() { diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 26ee5fe5e7..ec0a7139d3 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -313,6 +313,26 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) F.Mu = muSlice[0] } + // We define MONTY_MU = PRIME^-1 (mod 2^MONTY_BITS). This is different from the usual convention + // (MONTY_MU = -PRIME^-1 (mod 2^MONTY_BITS)) but it avoids a carry. + // 2164260865 + if F.F31 { + // _mu := big.NewInt(0) + // _mu.Set(&bModulus) + // _mu.Neg(_mu) + // _mu.ModInverse(_mu, big.NewInt(1<<31)) + // muSlice := toUint64Slice(_mu, F.NbWords) + // F.Mu = muSlice[0] + _r := big.NewInt(1) + _r.Lsh(_r, uint(F.NbWords)*radix) + _rInv := big.NewInt(1) + _qInv := big.NewInt(0) + extendedEuclideanAlgo(_r, &bModulus, _rInv, _qInv) + _qInv.Neg(_qInv) + _qInv.Mod(_qInv, _r) + F.Mu = toUint64Slice(_qInv, F.NbWords)[0] + } + return F, nil } diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index 5c1b5115f1..917415e500 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -253,6 +253,29 @@ func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { // Mul multiplies two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 8 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + if n % blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n % blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } } + +//go:noescape +func mulVec(res, a, b *{{.ElementName}}, n uint64) + ` diff --git a/field/koalabear/element_amd64.s b/field/koalabear/element_amd64.s index c52db5c441..da45886f9c 100644 --- a/field/koalabear/element_amd64.s +++ b/field/koalabear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 15742393007638587690 +// We include the hash to force the Go compiler to recompile: 15784575404660199304 #include "../asm/element_1w_amd64.s" diff --git a/field/koalabear/vector_amd64.go b/field/koalabear/vector_amd64.go index 5a53382ed7..6c79b605d2 100644 --- a/field/koalabear/vector_amd64.go +++ b/field/koalabear/vector_amd64.go @@ -124,5 +124,27 @@ func (vector *Vector) InnerProduct(other Vector) (res Element) { // Mul multiplies two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Mul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call mulVecGeneric + mulVecGeneric(*vector, a, b) + return + } + + const blockSize = 8 + mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + if n%blockSize != 0 { + // call mulVecGeneric on the rest + start := n - n%blockSize + mulVecGeneric((*vector)[start:], a[start:], b[start:]) + } } + +//go:noescape +func mulVec(res, a, b *Element, n uint64) diff --git a/go.mod b/go.mod index 45e3c85e9e..d1259d2cdf 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.23-0.20241208000453-1b3c9246dcd6 + github.com/consensys/bavard v0.1.23-0.20241208043834-8013eddc8088 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index 5adf0a8283..c3b9c1761b 100644 --- a/go.sum +++ b/go.sum @@ -71,6 +71,10 @@ github.com/consensys/bavard v0.1.23-0.20241207235803-84aa6b3d4724 h1:wBPDHYgf1Qv github.com/consensys/bavard v0.1.23-0.20241207235803-84aa6b3d4724/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/bavard v0.1.23-0.20241208000453-1b3c9246dcd6 h1:dm/VT++/p4tq8FLR/8z361AvWPD9dcp6xXebPLaEdZo= github.com/consensys/bavard v0.1.23-0.20241208000453-1b3c9246dcd6/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241208033340-073352297c17 h1:7fr9/A1Nm0L67XO33mdSTPD+2prj7VmlQ63YE9ujHW8= +github.com/consensys/bavard v0.1.23-0.20241208033340-073352297c17/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241208043834-8013eddc8088 h1:5fIHEbNpqWy7NWRzI8IInFCVSGyhH2BF4wpmY2XXg1k= +github.com/consensys/bavard v0.1.23-0.20241208043834-8013eddc8088/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= From f5e93c63dd93e76e13d0dbde2665e6abf7a1be2b Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 17:11:00 +0000 Subject: [PATCH 47/74] feat: clean up mul avx f31 --- field/asm/element_1w_amd64.s | 54 +++++++-------- field/babybear/element_amd64.s | 2 +- field/generator/asm/amd64/build.go | 2 +- field/generator/asm/amd64/element_vec_F31.go | 73 ++++++++------------ field/koalabear/element_amd64.s | 2 +- go.mod | 2 +- go.sum | 2 + 7 files changed, 60 insertions(+), 77 deletions(-) diff --git a/field/asm/element_1w_amd64.s b/field/asm/element_1w_amd64.s index 091d00aa18..89f853dc43 100644 --- a/field/asm/element_1w_amd64.s +++ b/field/asm/element_1w_amd64.s @@ -3,7 +3,7 @@ #include "funcdata.h" #include "go_asm.h" -// Vector operations are partially derived from Plonky3 https://github.com/Plonky3/Plonky3 +// (some) vector operations are partially derived from Plonky3 https://github.com/Plonky3/Plonky3 // addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] TEXT ·addVec(SB), NOSPLIT, $0-32 MOVD $const_q, AX @@ -18,10 +18,10 @@ loop_1: JEQ done_2 // n == 0, we are done VMOVDQU32 0(AX), Z0 VMOVDQU32 0(DX), Z1 - VPADDD Z0, Z1, Z0 - 
VPSUBD Z3, Z0, Z2 - VPMINUD Z0, Z2, Z1 - VMOVDQU32 Z1, 0(CX) + VPADDD Z0, Z1, Z0 // a = a + b + VPSUBD Z3, Z0, Z2 // t = a - q + VPMINUD Z0, Z2, Z1 // b = min(t, a) + VMOVDQU32 Z1, 0(CX) // res = b // increment pointers to visit next element ADDQ $64, AX @@ -47,10 +47,10 @@ loop_3: JEQ done_4 // n == 0, we are done VMOVDQU32 0(AX), Z0 VMOVDQU32 0(DX), Z1 - VPSUBD Z1, Z0, Z0 - VPADDD Z3, Z0, Z2 - VPMINUD Z0, Z2, Z1 - VMOVDQU32 Z1, 0(CX) + VPSUBD Z1, Z0, Z0 // a = a - b + VPADDD Z3, Z0, Z2 // t = a + q + VPMINUD Z0, Z2, Z1 // b = min(t, a) + VMOVDQU32 Z1, 0(CX) // res = b // increment pointers to visit next element ADDQ $64, AX @@ -73,16 +73,16 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ t+0(FP), R15 MOVQ a+8(FP), R14 MOVQ n+16(FP), CX - VXORPS Z2, Z2, Z2 - VMOVDQA64 Z2, Z3 + VXORPS Z2, Z2, Z2 // acc1 = 0 + VMOVDQA64 Z2, Z3 // acc2 = 0 loop_5: TESTQ CX, CX JEQ done_6 // n == 0, we are done - VPMOVZXDQ 0(R14), Z0 - VPMOVZXDQ 32(R14), Z1 - VPADDQ Z0, Z2, Z2 - VPADDQ Z1, Z3, Z3 + VPMOVZXDQ 0(R14), Z0 // load 8 31bits values in a1 + VPMOVZXDQ 32(R14), Z1 // load 8 31bits values in a2 + VPADDQ Z0, Z2, Z2 // acc1 += a1 + VPADDQ Z1, Z3, Z3 // acc2 += a2 // increment pointers to visit next element ADDQ $64, R14 @@ -90,8 +90,8 @@ loop_5: JMP loop_5 done_6: - VPADDQ Z2, Z3, Z2 - VMOVDQU64 Z2, 0(R15) + VPADDQ Z2, Z3, Z2 // acc1 += acc2 + VMOVDQU64 Z2, 0(R15) // res = acc1 RET // mulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b[0...n] @@ -114,16 +114,16 @@ loop_7: JEQ done_8 // n == 0, we are done VPMOVZXDQ 0(AX), Z0 VPMOVZXDQ 0(DX), Z1 - VPMULUDQ Z0, Z1, Z2 - VPANDQ Z6, Z2, Z5 - VPMULUDQ Z5, Z4, Z5 - VPANDQ Z6, Z5, Z5 - VPMULUDQ Z5, Z3, Z5 - VPADDQ Z2, Z5, Z2 - VPSRLQ $32, Z2, Z2 - VPSUBD Z3, Z2, Z5 - VPMINUD Z2, Z5, Z2 - VPMOVQD Z2, 0(CX) + VPMULUDQ Z0, Z1, Z2 // P = a * b + VPANDQ Z6, Z2, Z5 // m = uint32(P) + VPMULUDQ Z5, Z4, Z5 // m = m * qInvNeg + VPANDQ Z6, Z5, Z5 // m = uint32(m) + VPMULUDQ Z5, Z3, Z5 // m = m * q + VPADDQ Z2, Z5, Z2 // P = P + m + VPSRLQ $32, Z2, Z2 // P = P >> 32 + VPSUBD Z3, Z2, Z5 // PL = P - q + VPMINUD Z2, Z5, Z2 // P = min(P, PL) + VPMOVQD Z2, 0(CX) // res = P // increment pointers to visit next element ADDQ $32, AX diff --git a/field/babybear/element_amd64.s b/field/babybear/element_amd64.s index da45886f9c..2bbcdd363c 100644 --- a/field/babybear/element_amd64.s +++ b/field/babybear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 15784575404660199304 +// We include the hash to force the Go compiler to recompile: 992380086225728104 #include "../asm/element_1w_amd64.s" diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 7118ff9178..7d766ac8d0 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -319,7 +319,7 @@ func GenerateF31ASM(f *FFAmd64, hasVector bool) error { return nil // nothing for now. 
} - f.Comment("Vector operations are partially derived from Plonky3 https://github.com/Plonky3/Plonky3") + f.Comment("(some) vector operations are partially derived from Plonky3 https://github.com/Plonky3/Plonky3") f.generateAddVecF31() f.generateSubVecF31() diff --git a/field/generator/asm/amd64/element_vec_F31.go b/field/generator/asm/amd64/element_vec_F31.go index c9d006af74..9a957040fa 100644 --- a/field/generator/asm/amd64/element_vec_F31.go +++ b/field/generator/asm/amd64/element_vec_F31.go @@ -15,8 +15,6 @@ package amd64 import ( - "fmt" - "github.com/consensys/bavard/amd64" ) @@ -63,14 +61,14 @@ func (f *FFAmd64) generateAddVecF31() { // a = a + b f.VMOVDQU32(addrA.At(0), a) f.VMOVDQU32(addrB.At(0), b) - f.VPADDD(a, b, a) + f.VPADDD(a, b, a, "a = a + b") // t = a - q - f.VPSUBD(q, a, t) + f.VPSUBD(q, a, t, "t = a - q") // b = min(t, a) - f.VPMINUD(a, t, b) + f.VPMINUD(a, t, b, "b = min(t, a)") // move b to res - f.VMOVDQU32(b, addrRes.At(0)) + f.VMOVDQU32(b, addrRes.At(0), "res = b") f.Comment("increment pointers to visit next element") f.ADDQ("$64", addrA) @@ -131,16 +129,16 @@ func (f *FFAmd64) generateSubVecF31() { f.VMOVDQU32(addrA.At(0), a) f.VMOVDQU32(addrB.At(0), b) - f.VPSUBD(b, a, a) + f.VPSUBD(b, a, a, "a = a - b") // t = a + q - f.VPADDD(q, a, t) + f.VPADDD(q, a, t, "t = a + q") // b = min(t, a) - f.VPMINUD(a, t, b) + f.VPMINUD(a, t, b, "b = min(t, a)") // move b to res - f.VMOVDQU32(b, addrRes.At(0)) + f.VMOVDQU32(b, addrRes.At(0), "res = b") f.Comment("increment pointers to visit next element") f.ADDQ("$64", addrA) @@ -192,8 +190,8 @@ func (f *FFAmd64) generateSumVecF31() { f.MOVQ("n+16(FP)", len) // zeroize the accumulators - f.VXORPS(acc1, acc1, acc1) - f.VMOVDQA64(acc1, acc2) + f.VXORPS(acc1, acc1, acc1, "acc1 = 0") + f.VMOVDQA64(acc1, acc2, "acc2 = 0") f.LABEL(loop) @@ -201,11 +199,11 @@ func (f *FFAmd64) generateSumVecF31() { f.JEQ(done, "n == 0, we are done") // 1 cache line is typically 64 bytes, so we maintain 2 accumulators - f.VPMOVZXDQ(addrA.At(0), a1) - f.VPMOVZXDQ(addrA.At(4), a2) + f.VPMOVZXDQ(addrA.At(0), a1, "load 8 31bits values in a1") + f.VPMOVZXDQ(addrA.At(4), a2, "load 8 31bits values in a2") - f.VPADDQ(a1, acc1, acc1) - f.VPADDQ(a2, acc2, acc2) + f.VPADDQ(a1, acc1, acc1, "acc1 += a1") + f.VPADDQ(a2, acc2, acc2, "acc2 += a2") f.Comment("increment pointers to visit next element") f.ADDQ("$64", addrA) @@ -215,8 +213,8 @@ func (f *FFAmd64) generateSumVecF31() { f.LABEL(done) // store t into res - f.VPADDQ(acc1, acc2, acc1) - f.VMOVDQU64(acc1, addrT.At(0)) + f.VPADDQ(acc1, acc2, acc1, "acc1 += acc2") + f.VMOVDQU64(acc1, addrT.At(0), "res = acc1") f.RET() @@ -274,36 +272,19 @@ func (f *FFAmd64) generateMulVecF31() { // a = a * b f.VPMOVZXDQ(addrA.At(0), a) f.VPMOVZXDQ(addrB.At(0), b) - f.VPMULUDQ(a, b, P) - // f.VPSRLQ("$32", P, PH) - f.VPANDQ(LSW, P, PL) // low dword - // m := uint32(v) * qInvNeg --> m = PL * qInvNeg - f.VPMULUDQ(PL, qInvNeg, PL) - f.VPANDQ(LSW, PL, PL) // mod R --> keep low dword - // m*=q - f.VPMULUDQ(PL, q, PL) - // add P - f.VPADDQ(P, PL, P) - f.VPSRLQ("$32", P, P) // shift right by 32 bits - - // now we need to use min to reduce - // first sub q from P - f.VPSUBD(q, P, PL) - - // res = min(P, PL) - f.VPMINUD(P, PL, P) - - // move P to res - f.WriteLn(fmt.Sprintf("VPMOVQD %s, %s", P, addrRes.At(0))) - // f.VMOVDQU32(P, addrRes.At(0)) + f.VPMULUDQ(a, b, P, "P = a * b") + f.VPANDQ(LSW, P, PL, "m = uint32(P)") + f.VPMULUDQ(PL, qInvNeg, PL, "m = m * qInvNeg") + f.VPANDQ(LSW, PL, PL, "m = uint32(m)") + f.VPMULUDQ(PL, q, PL, "m 
= m * q") + f.VPADDQ(P, PL, P, "P = P + m") + f.VPSRLQ("$32", P, P, "P = P >> 32") - // now we need to montReduce + f.VPSUBD(q, P, PL, "PL = P - q") + f.VPMINUD(P, PL, P, "P = min(P, PL)") - // // a = a - b - // f.VMOVDQU32(addrA.At(0), a) - // f.VMOVDQU32(addrB.At(0), b) - - // f.VPSUBD(b, a, a) + // move P to res + f.VPMOVQD(P, addrRes.At(0), "res = P") f.Comment("increment pointers to visit next element") f.ADDQ("$32", addrA) diff --git a/field/koalabear/element_amd64.s b/field/koalabear/element_amd64.s index da45886f9c..2bbcdd363c 100644 --- a/field/koalabear/element_amd64.s +++ b/field/koalabear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 15784575404660199304 +// We include the hash to force the Go compiler to recompile: 992380086225728104 #include "../asm/element_1w_amd64.s" diff --git a/go.mod b/go.mod index d1259d2cdf..2c0fdd0127 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.23-0.20241208043834-8013eddc8088 + github.com/consensys/bavard v0.1.23-0.20241208050306-51bece05ad82 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index c3b9c1761b..220ff7c89a 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,8 @@ github.com/consensys/bavard v0.1.23-0.20241208033340-073352297c17 h1:7fr9/A1Nm0L github.com/consensys/bavard v0.1.23-0.20241208033340-073352297c17/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/bavard v0.1.23-0.20241208043834-8013eddc8088 h1:5fIHEbNpqWy7NWRzI8IInFCVSGyhH2BF4wpmY2XXg1k= github.com/consensys/bavard v0.1.23-0.20241208043834-8013eddc8088/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241208050306-51bece05ad82 h1:S2g4cxvQWO4KXQygLFtoyx4Q3lqoKRk/zJFFOYtJiJw= +github.com/consensys/bavard v0.1.23-0.20241208050306-51bece05ad82/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= From 17beb83f28b937d51084047926d91ba308a19fff Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 17:30:02 +0000 Subject: [PATCH 48/74] feat: add avx512 scalarMul vec for F31 --- field/asm/element_1w_amd64.s | 40 +++++++++ field/babybear/element_amd64.s | 2 +- field/babybear/vector_amd64.go | 48 ++++++++--- field/generator/asm/amd64/build.go | 1 + field/generator/asm/amd64/element_vec_F31.go | 81 +++++++++++++++++++ .../templates/element/vector_ops_asm.go | 49 ++++++++--- field/koalabear/element_amd64.s | 2 +- field/koalabear/vector_amd64.go | 48 ++++++++--- 8 files changed, 231 insertions(+), 40 deletions(-) diff --git a/field/asm/element_1w_amd64.s b/field/asm/element_1w_amd64.s index 89f853dc43..028c26b2a6 100644 --- a/field/asm/element_1w_amd64.s +++ b/field/asm/element_1w_amd64.s @@ -134,3 +134,43 @@ loop_7: done_8: RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), NOSPLIT, $0-32 + MOVD $const_q, AX + VPBROADCASTQ AX, Z3 + MOVD $const_qInvNeg, AX + VPBROADCASTQ AX, Z4 + + // Create mask for low dword in each qword + VPCMPEQB Y0, Y0, Y0 + 
VPMOVZXDQ Y0, Z6 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + VPBROADCASTD 0(DX), Z1 + +loop_9: + TESTQ BX, BX + JEQ done_10 // n == 0, we are done + VPMOVZXDQ 0(AX), Z0 + VPMULUDQ Z0, Z1, Z2 // P = a * b + VPANDQ Z6, Z2, Z5 // m = uint32(P) + VPMULUDQ Z5, Z4, Z5 // m = m * qInvNeg + VPANDQ Z6, Z5, Z5 // m = uint32(m) + VPMULUDQ Z5, Z3, Z5 // m = m * q + VPADDQ Z2, Z5, Z2 // P = P + m + VPSRLQ $32, Z2, Z2 // P = P >> 32 + VPSUBD Z3, Z2, Z5 // PL = P - q + VPMINUD Z2, Z5, Z2 // P = min(P, PL) + VPMOVQD Z2, 0(CX) // res = P + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_9 + +done_10: + RET diff --git a/field/babybear/element_amd64.s b/field/babybear/element_amd64.s index 2bbcdd363c..1f5f9c9243 100644 --- a/field/babybear/element_amd64.s +++ b/field/babybear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 992380086225728104 +// We include the hash to force the Go compiler to recompile: 2825667293017790075 #include "../asm/element_1w_amd64.s" diff --git a/field/babybear/vector_amd64.go b/field/babybear/vector_amd64.go index 650e02eee7..1243b98848 100644 --- a/field/babybear/vector_amd64.go +++ b/field/babybear/vector_amd64.go @@ -18,6 +18,21 @@ package babybear +//go:noescape +func addVec(res, a, b *Element, n uint64) + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +//go:noescape +func sumVec(t *uint64, a *Element, n uint64) + +//go:noescape +func mulVec(res, a, b *Element, n uint64) + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64) + // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -43,9 +58,6 @@ func (vector *Vector) Add(a, b Vector) { } } -//go:noescape -func addVec(res, a, b *Element, n uint64) - // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Sub(a, b Vector) { @@ -71,13 +83,29 @@ func (vector *Vector) Sub(a, b Vector) { } } -//go:noescape -func subVec(res, a, b *Element, n uint64) - // ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + + const blockSize = 8 + scalarMulVec(&(*vector)[0], &a[0], b, n/blockSize) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } } // Sum computes the sum of all elements in the vector. @@ -111,9 +139,6 @@ func (vector *Vector) Sum() (res Element) { return } -//go:noescape -func sumVec(t *uint64, a *Element, n uint64) - // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. 
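scalarMulVec above reuses the mulVec lane sequence but hoists the broadcast of the scalar (VPBROADCASTD 0(DX), Z1) out of the loop; like mulVec it widens 8 dwords to 8 qwords per iteration, hence blockSize = 8 rather than the 16 used by add and sub. A usage sketch (values arbitrary):

    func scaleByFive(v Vector) {
        var lambda Element
        lambda.SetUint64(5)
        v.ScalarMul(v, &lambda) // in place: v[i] = v[i] * lambda
    }
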
func (vector *Vector) InnerProduct(other Vector) (res Element) { @@ -145,6 +170,3 @@ func (vector *Vector) Mul(a, b Vector) { mulVecGeneric((*vector)[start:], a[start:], b[start:]) } } - -//go:noescape -func mulVec(res, a, b *Element, n uint64) diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 7d766ac8d0..8692de28c7 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -325,6 +325,7 @@ func GenerateF31ASM(f *FFAmd64, hasVector bool) error { f.generateSubVecF31() f.generateSumVecF31() f.generateMulVecF31() + f.generateScalarMulVecF31() return nil } diff --git a/field/generator/asm/amd64/element_vec_F31.go b/field/generator/asm/amd64/element_vec_F31.go index 9a957040fa..8f3ace4875 100644 --- a/field/generator/asm/amd64/element_vec_F31.go +++ b/field/generator/asm/amd64/element_vec_F31.go @@ -301,6 +301,87 @@ func (f *FFAmd64) generateMulVecF31() { } +// scalarMulVec res = a * b +func (f *FFAmd64) generateScalarMulVecF31() { + f.Comment("scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b") + + const argSize = 4 * 8 + stackSize := f.StackSize(f.NbWords*2+4, 0, 0) + registers := f.FnHeader("scalarMulVec", stackSize, argSize) + defer f.AssertCleanStack(stackSize, 0) + + // registers & labels we need + addrA := f.Pop(®isters) + addrB := f.Pop(®isters) + addrRes := f.Pop(®isters) + len := f.Pop(®isters) + + // AVX512 registers + a := amd64.Register("Z0") + b := amd64.Register("Z1") + P := amd64.Register("Z2") + q := amd64.Register("Z3") + qInvNeg := amd64.Register("Z4") + PL := amd64.Register("Z5") + LSW := amd64.Register("Z6") + + // load q in Z3 + f.WriteLn("MOVD $const_q, AX") + f.VPBROADCASTQ("AX", q) + f.WriteLn("MOVD $const_qInvNeg, AX") + f.VPBROADCASTQ("AX", qInvNeg) + + f.Comment("Create mask for low dword in each qword") + f.VPCMPEQB("Y0", "Y0", "Y0") + f.VPMOVZXDQ("Y0", LSW) + + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.MOVQ("res+0(FP)", addrRes) + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("b+16(FP)", addrB) + f.MOVQ("n+24(FP)", len) + + f.VPBROADCASTD(addrB.At(0), b) + + f.LABEL(loop) + + f.TESTQ(len, len) + f.JEQ(done, "n == 0, we are done") + + // a = a * b + f.VPMOVZXDQ(addrA.At(0), a) + + f.VPMULUDQ(a, b, P, "P = a * b") + f.VPANDQ(LSW, P, PL, "m = uint32(P)") + f.VPMULUDQ(PL, qInvNeg, PL, "m = m * qInvNeg") + f.VPANDQ(LSW, PL, PL, "m = uint32(m)") + f.VPMULUDQ(PL, q, PL, "m = m * q") + f.VPADDQ(P, PL, P, "P = P + m") + f.VPSRLQ("$32", P, P, "P = P >> 32") + + f.VPSUBD(q, P, PL, "PL = P - q") + f.VPMINUD(P, PL, P, "P = min(P, PL)") + + // move P to res + f.VPMOVQD(P, addrRes.At(0), "res = P") + + f.Comment("increment pointers to visit next element") + f.ADDQ("$32", addrA) + f.ADDQ("$32", addrRes) + f.DECQ(len, "decrement n") + f.JMP(loop) + + f.LABEL(done) + + f.RET() + + f.Push(®isters, addrA, addrB, addrRes, len) + +} + // // subVec res = a - b // // func subVec(res, a, b *{{.ElementName}}, n uint64) // func (f *FFAmd64) generateSubVecW4() { diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index 917415e500..7683b2bae4 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -147,6 +147,22 @@ func mulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64) const VectorOpsArm64 = VectorOpsPureGo const VectorOpsAmd64F31 = ` + +//go:noescape +func addVec(res, a, b *{{.ElementName}}, 
n uint64) + +//go:noescape +func subVec(res, a, b *{{.ElementName}}, n uint64) + +//go:noescape +func sumVec(t *uint64, a *{{.ElementName}}, n uint64) + +//go:noescape +func mulVec(res, a, b *{{.ElementName}}, n uint64) + +//go:noescape +func scalarMulVec(res, a, b *{{.ElementName}}, n uint64) + // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -172,9 +188,6 @@ func (vector *Vector) Add(a, b Vector) { } } -//go:noescape -func addVec(res, a, b *{{.ElementName}}, n uint64) - // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Sub(a, b Vector) { @@ -200,13 +213,29 @@ func (vector *Vector) Sub(a, b Vector) { } } -//go:noescape -func subVec(res, a, b *{{.ElementName}}, n uint64) - // ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { - scalarMulVecGeneric(*vector, a, b) + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + + const blockSize = 8 + scalarMulVec(&(*vector)[0], &a[0], b, n/blockSize) + if n % blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n % blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } } // Sum computes the sum of all elements in the vector. @@ -240,9 +269,6 @@ func (vector *Vector) Sum() (res {{.ElementName}}) { return } -//go:noescape -func sumVec(t *uint64, a *{{.ElementName}}, n uint64) - // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { @@ -275,7 +301,6 @@ func (vector *Vector) Mul(a, b Vector) { } } -//go:noescape -func mulVec(res, a, b *{{.ElementName}}, n uint64) + ` diff --git a/field/koalabear/element_amd64.s b/field/koalabear/element_amd64.s index 2bbcdd363c..1f5f9c9243 100644 --- a/field/koalabear/element_amd64.s +++ b/field/koalabear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 992380086225728104 +// We include the hash to force the Go compiler to recompile: 2825667293017790075 #include "../asm/element_1w_amd64.s" diff --git a/field/koalabear/vector_amd64.go b/field/koalabear/vector_amd64.go index 6c79b605d2..b773cc4539 100644 --- a/field/koalabear/vector_amd64.go +++ b/field/koalabear/vector_amd64.go @@ -18,6 +18,21 @@ package koalabear +//go:noescape +func addVec(res, a, b *Element, n uint64) + +//go:noescape +func subVec(res, a, b *Element, n uint64) + +//go:noescape +func sumVec(t *uint64, a *Element, n uint64) + +//go:noescape +func mulVec(res, a, b *Element, n uint64) + +//go:noescape +func scalarMulVec(res, a, b *Element, n uint64) + // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -43,9 +58,6 @@ func (vector *Vector) Add(a, b Vector) { } } -//go:noescape -func addVec(res, a, b *Element, n uint64) - // Sub subtracts two vectors element-wise and stores the result in self. 
// It panics if the vectors don't have the same length. func (vector *Vector) Sub(a, b Vector) { @@ -71,13 +83,29 @@ func (vector *Vector) Sub(a, b Vector) { } } -//go:noescape -func subVec(res, a, b *Element, n uint64) - // ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) + if len(a) != len(*vector) { + panic("vector.ScalarMul: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + if !supportAvx512 { + // call scalarMulVecGeneric + scalarMulVecGeneric(*vector, a, b) + return + } + + const blockSize = 8 + scalarMulVec(&(*vector)[0], &a[0], b, n/blockSize) + if n%blockSize != 0 { + // call scalarMulVecGeneric on the rest + start := n - n%blockSize + scalarMulVecGeneric((*vector)[start:], a[start:], b) + } } // Sum computes the sum of all elements in the vector. @@ -111,9 +139,6 @@ func (vector *Vector) Sum() (res Element) { return } -//go:noescape -func sumVec(t *uint64, a *Element, n uint64) - // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. func (vector *Vector) InnerProduct(other Vector) (res Element) { @@ -145,6 +170,3 @@ func (vector *Vector) Mul(a, b Vector) { mulVecGeneric((*vector)[start:], a[start:], b[start:]) } } - -//go:noescape -func mulVec(res, a, b *Element, n uint64) From c1b06b41f1f999ee16c9b419d1e5a6001a4a1779 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 17:49:25 +0000 Subject: [PATCH 49/74] feat: add innerProdVec avx512 for f31 --- field/asm/element_1w_amd64.s | 48 +++++++++- field/babybear/arith.go | 60 ------------- field/babybear/asm_adx.go | 26 ------ field/babybear/asm_noadx.go | 27 ------ field/babybear/element_amd64.go | 66 -------------- field/babybear/element_amd64.s | 2 +- field/babybear/vector_amd64.go | 33 ++++++- field/generator/asm/amd64/build.go | 3 +- field/generator/asm/amd64/element_vec_F31.go | 87 +++++++++++++++++++ .../templates/element/vector_ops_asm.go | 33 ++++++- field/koalabear/element_amd64.s | 2 +- field/koalabear/vector_amd64.go | 33 ++++++- 12 files changed, 233 insertions(+), 187 deletions(-) delete mode 100644 field/babybear/arith.go delete mode 100644 field/babybear/asm_adx.go delete mode 100644 field/babybear/asm_noadx.go delete mode 100644 field/babybear/element_amd64.go diff --git a/field/asm/element_1w_amd64.s b/field/asm/element_1w_amd64.s index 028c26b2a6..655e57cddf 100644 --- a/field/asm/element_1w_amd64.s +++ b/field/asm/element_1w_amd64.s @@ -3,7 +3,6 @@ #include "funcdata.h" #include "go_asm.h" -// (some) vector operations are partially derived from Plonky3 https://github.com/Plonky3/Plonky3 // addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] TEXT ·addVec(SB), NOSPLIT, $0-32 MOVD $const_q, AX @@ -174,3 +173,50 @@ loop_9: done_10: RET + +// innerProdVec(t *uint64, a,b *[]uint32, n uint64) res = sum(a[0...n] * b[0...n]) +TEXT ·innerProdVec(SB), NOSPLIT, $0-32 + + // Similar to mulVec; we do most of the montgomery multiplication but don't do + // the final reduction. We accumulate the result like in sumVec and let the caller + // reduce mod q. 
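+	// In scalar Go terms, each 64-bit lane computes (a minimal sketch):
+	//   p := uint64(a) * uint64(b)
+	//   m := uint32(uint64(uint32(p)) * qInvNeg) // computed mod 2^32
+	//   p = (p + uint64(m)*q) >> 32              // p < 2q, fits on 32 bits
+	//   acc += p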
+ + MOVD $const_q, AX + VPBROADCASTQ AX, Z3 + MOVD $const_qInvNeg, AX + VPBROADCASTQ AX, Z4 + + // Create mask for low dword in each qword + VPCMPEQB Y0, Y0, Y0 + VPMOVZXDQ Y0, Z6 + VXORPS Z2, Z2, Z2 // acc = 0 + MOVQ t+0(FP), CX + MOVQ a+8(FP), R14 + MOVQ b+16(FP), R15 + MOVQ n+24(FP), BX + +loop_11: + TESTQ BX, BX + JEQ done_12 // n == 0, we are done + VPMOVZXDQ 0(R14), Z0 + VPMOVZXDQ 0(R15), Z1 + VPMULUDQ Z0, Z1, Z7 // P = a * b + VPANDQ Z6, Z7, Z5 // m = uint32(P) + VPMULUDQ Z5, Z4, Z5 // m = m * qInvNeg + VPANDQ Z6, Z5, Z5 // m = uint32(m) + VPMULUDQ Z5, Z3, Z5 // m = m * q + VPADDQ Z7, Z5, Z7 // P = P + m + VPSRLQ $32, Z7, Z7 // P = P >> 32 + + // accumulate P into acc, P is in [0, 2q] on 32bits max + VPADDQ Z7, Z2, Z2 // acc += P + + // increment pointers to visit next element + ADDQ $32, R14 + ADDQ $32, R15 + DECQ BX // decrement n + JMP loop_11 + +done_12: + VMOVDQU64 Z2, 0(CX) // res = acc + RET diff --git a/field/babybear/arith.go b/field/babybear/arith.go deleted file mode 100644 index 3dfd7e5ffe..0000000000 --- a/field/babybear/arith.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package babybear - -import ( - "math/bits" -) - -// madd0 hi = a*b + c (discards lo bits) -func madd0(a, b, c uint64) (hi uint64) { - var carry, lo uint64 - hi, lo = bits.Mul64(a, b) - _, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) - return -} - -// madd1 hi, lo = a*b + c -func madd1(a, b, c uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) - return -} - -// madd2 hi, lo = a*b + c + d -func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - c, carry = bits.Add64(c, d, 0) - hi, _ = bits.Add64(hi, 0, carry) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) - return -} - -func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - c, carry = bits.Add64(c, d, 0) - hi, _ = bits.Add64(hi, 0, carry) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, e, carry) - return -} diff --git a/field/babybear/asm_adx.go b/field/babybear/asm_adx.go deleted file mode 100644 index 33faad5ec2..0000000000 --- a/field/babybear/asm_adx.go +++ /dev/null @@ -1,26 +0,0 @@ -//go:build !noadx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package babybear - -import "golang.org/x/sys/cpu" - -var ( - supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 - _ = supportAdx -) diff --git a/field/babybear/asm_noadx.go b/field/babybear/asm_noadx.go deleted file mode 100644 index c01b8ba5dc..0000000000 --- a/field/babybear/asm_noadx.go +++ /dev/null @@ -1,27 +0,0 @@ -//go:build noadx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package babybear - -// note: this is needed for test purposes, as dynamically changing supportAdx doesn't flag -// certain errors (like fatal error: missing stackmap) -// this ensures we test all asm path. -var ( - supportAdx = false - _ = supportAdx -) diff --git a/field/babybear/element_amd64.go b/field/babybear/element_amd64.go deleted file mode 100644 index 9a0c54891f..0000000000 --- a/field/babybear/element_amd64.go +++ /dev/null @@ -1,66 +0,0 @@ -//go:build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package babybear - -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -// -//go:noescape -func Butterfly(a, b *Element) - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for doc. 
- mul(z, x, x) - return z -} diff --git a/field/babybear/element_amd64.s b/field/babybear/element_amd64.s index 1f5f9c9243..da9b9a9d73 100644 --- a/field/babybear/element_amd64.s +++ b/field/babybear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 2825667293017790075 +// We include the hash to force the Go compiler to recompile: 4604324921634291592 #include "../asm/element_1w_amd64.s" diff --git a/field/babybear/vector_amd64.go b/field/babybear/vector_amd64.go index 1243b98848..2b437b6811 100644 --- a/field/babybear/vector_amd64.go +++ b/field/babybear/vector_amd64.go @@ -33,6 +33,9 @@ func mulVec(res, a, b *Element, n uint64) //go:noescape func scalarMulVec(res, a, b *Element, n uint64) +//go:noescape +func innerProdVec(t *uint64, a, b *Element, n uint64) + // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -142,7 +145,35 @@ func (vector *Vector) Sum() (res Element) { // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + if !supportAvx512 { + // call innerProductVecGeneric + innerProductVecGeneric(&res, *vector, other) + return + } + + const blockSize = 8 + var t [8]uint64 // stores the accumulators (not reduced mod q) + innerProdVec(&t[0], &(*vector)[0], &other[0], n/blockSize) + // we reduce the accumulators mod q and add to res + var v Element + for i := 0; i < 8; i++ { + t[i] %= q + v[0] = uint32(t[i]) + res.Add(&res, &v) + } + if n%blockSize != 0 { + // call innerProductVecGeneric on the rest + start := n - n%blockSize + innerProductVecGeneric(&res, (*vector)[start:], other[start:]) + } + return } diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 8692de28c7..30d482236b 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -319,13 +319,12 @@ func GenerateF31ASM(f *FFAmd64, hasVector bool) error { return nil // nothing for now. } - f.Comment("(some) vector operations are partially derived from Plonky3 https://github.com/Plonky3/Plonky3") - f.generateAddVecF31() f.generateSubVecF31() f.generateSumVecF31() f.generateMulVecF31() f.generateScalarMulVecF31() + f.generateInnerProdVecF31() return nil } diff --git a/field/generator/asm/amd64/element_vec_F31.go b/field/generator/asm/amd64/element_vec_F31.go index 8f3ace4875..538b971940 100644 --- a/field/generator/asm/amd64/element_vec_F31.go +++ b/field/generator/asm/amd64/element_vec_F31.go @@ -379,7 +379,94 @@ func (f *FFAmd64) generateScalarMulVecF31() { f.RET() f.Push(®isters, addrA, addrB, addrRes, len) +} + +// innerProdVec res = sum(a * b) +func (f *FFAmd64) generateInnerProdVecF31() { + f.Comment("innerProdVec(t *uint64, a,b *[]uint32, n uint64) res = sum(a[0...n] * b[0...n])") + + const argSize = 4 * 8 + stackSize := f.StackSize(f.NbWords*4+2, 0, 0) + registers := f.FnHeader("innerProdVec", stackSize, argSize, amd64.DX, amd64.AX) + defer f.AssertCleanStack(stackSize, 0) + + f.WriteLn(` + // Similar to mulVec; we do most of the montgomery multiplication but don't do + // the final reduction. 
We accumulate the result like in sumVec and let the caller + // reduce mod q. + `) + + // registers & labels we need + addrA := f.Pop(®isters) + addrB := f.Pop(®isters) + addrT := f.Pop(®isters) + len := f.Pop(®isters) + + // AVX512 registers + a := amd64.Register("Z0") + b := amd64.Register("Z1") + acc := amd64.Register("Z2") + q := amd64.Register("Z3") + qInvNeg := amd64.Register("Z4") + PL := amd64.Register("Z5") + LSW := amd64.Register("Z6") + P := amd64.Register("Z7") + + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + f.WriteLn("MOVD $const_q, AX") + f.VPBROADCASTQ("AX", q) + f.WriteLn("MOVD $const_qInvNeg, AX") + f.VPBROADCASTQ("AX", qInvNeg) + + f.Comment("Create mask for low dword in each qword") + f.VPCMPEQB("Y0", "Y0", "Y0") + f.VPMOVZXDQ("Y0", LSW) + + // zeroize the accumulators + f.VXORPS(acc, acc, acc, "acc = 0") + // load arguments + f.MOVQ("t+0(FP)", addrT) + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("b+16(FP)", addrB) + f.MOVQ("n+24(FP)", len) + + f.LABEL(loop) + + f.TESTQ(len, len) + f.JEQ(done, "n == 0, we are done") + + f.VPMOVZXDQ(addrA.At(0), a) + f.VPMOVZXDQ(addrB.At(0), b) + + f.VPMULUDQ(a, b, P, "P = a * b") + f.VPANDQ(LSW, P, PL, "m = uint32(P)") + f.VPMULUDQ(PL, qInvNeg, PL, "m = m * qInvNeg") + f.VPANDQ(LSW, PL, PL, "m = uint32(m)") + f.VPMULUDQ(PL, q, PL, "m = m * q") + f.VPADDQ(P, PL, P, "P = P + m") + f.VPSRLQ("$32", P, P, "P = P >> 32") + + // TODO @gbotrel comment on the bound and ensure caller can't trigger overflow in accumulator. + f.Comment("accumulate P into acc, P is in [0, 2q] on 32bits max") + f.VPADDQ(P, acc, acc, "acc += P") + + f.Comment("increment pointers to visit next element") + f.ADDQ("$32", addrA) + f.ADDQ("$32", addrB) + f.DECQ(len, "decrement n") + f.JMP(loop) + + f.LABEL(done) + + // store t into res + f.VMOVDQU64(acc, addrT.At(0), "res = acc") + + f.RET() + + f.Push(®isters, addrA, addrT, len) } // // subVec res = a - b diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index 7683b2bae4..aa00830dfd 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -163,6 +163,9 @@ func mulVec(res, a, b *{{.ElementName}}, n uint64) //go:noescape func scalarMulVec(res, a, b *{{.ElementName}}, n uint64) +//go:noescape +func innerProdVec(t *uint64, a, b *{{.ElementName}}, n uint64) + // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -272,7 +275,35 @@ func (vector *Vector) Sum() (res {{.ElementName}}) { // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. 
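+// On AVX512 builds, blocks of 8 elements are accumulated by innerProdVec in
+// assembly; the remaining n mod 8 elements fall back to the generic code.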
func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { - innerProductVecGeneric(&res, *vector, other) + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + if !supportAvx512 { + // call innerProductVecGeneric + innerProductVecGeneric(&res, *vector, other) + return + } + + const blockSize = 8 + var t [8]uint64 // stores the accumulators (not reduced mod q) + innerProdVec(&t[0], &(*vector)[0], &other[0], n/blockSize) + // we reduce the accumulators mod q and add to res + var v {{.ElementName}} + for i := 0; i < 8; i++ { + t[i] %= q + v[0] = uint32(t[i]) + res.Add(&res, &v) + } + if n % blockSize != 0 { + // call innerProductVecGeneric on the rest + start := n - n % blockSize + innerProductVecGeneric(&res, (*vector)[start:], other[start:]) + } + return } diff --git a/field/koalabear/element_amd64.s b/field/koalabear/element_amd64.s index 1f5f9c9243..da9b9a9d73 100644 --- a/field/koalabear/element_amd64.s +++ b/field/koalabear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 2825667293017790075 +// We include the hash to force the Go compiler to recompile: 4604324921634291592 #include "../asm/element_1w_amd64.s" diff --git a/field/koalabear/vector_amd64.go b/field/koalabear/vector_amd64.go index b773cc4539..1068935c7f 100644 --- a/field/koalabear/vector_amd64.go +++ b/field/koalabear/vector_amd64.go @@ -33,6 +33,9 @@ func mulVec(res, a, b *Element, n uint64) //go:noescape func scalarMulVec(res, a, b *Element, n uint64) +//go:noescape +func innerProdVec(t *uint64, a, b *Element, n uint64) + // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -142,7 +145,35 @@ func (vector *Vector) Sum() (res Element) { // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. 
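+// The assembly routine fills t with 8 unreduced 64-bit lane sums, which are
+// then reduced mod q and folded into res below.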
func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) + n := uint64(len(*vector)) + if n == 0 { + return + } + if n != uint64(len(other)) { + panic("vector.InnerProduct: vectors don't have the same length") + } + if !supportAvx512 { + // call innerProductVecGeneric + innerProductVecGeneric(&res, *vector, other) + return + } + + const blockSize = 8 + var t [8]uint64 // stores the accumulators (not reduced mod q) + innerProdVec(&t[0], &(*vector)[0], &other[0], n/blockSize) + // we reduce the accumulators mod q and add to res + var v Element + for i := 0; i < 8; i++ { + t[i] %= q + v[0] = uint32(t[i]) + res.Add(&res, &v) + } + if n%blockSize != 0 { + // call innerProductVecGeneric on the rest + start := n - n%blockSize + innerProductVecGeneric(&res, (*vector)[start:], other[start:]) + } + return } From 9af3aed89f0547724099772af5e35fdc988ae9db Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 15:28:59 -0600 Subject: [PATCH 50/74] style: code cleaning --- field/asm/element_1w_amd64.s | 10 +++++----- field/babybear/element_amd64.s | 2 +- field/generator/asm/amd64/element_vec_F31.go | 12 ++++++------ field/generator/config/field_config.go | 20 -------------------- field/koalabear/element_amd64.s | 2 +- 5 files changed, 13 insertions(+), 33 deletions(-) diff --git a/field/asm/element_1w_amd64.s b/field/asm/element_1w_amd64.s index 655e57cddf..f8f134291f 100644 --- a/field/asm/element_1w_amd64.s +++ b/field/asm/element_1w_amd64.s @@ -62,13 +62,13 @@ done_4: RET // sumVec(res *uint64, a *[]uint32, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 -// We are load 8 31bits values at a time and accumulate them into an accumulator of -// 8 quadwords (64bits). The caller then needs to reduce the result mod q. -// We can safely accumulate ~2**33 31bits values into a single accumulator. -// That gives us a maximum of 2**33 * 8 = 2**36 31bits values to sum safely. + // We load 8 31bits values at a time and accumulate them into an accumulator of + // 8 quadwords (64bits). The caller then needs to reduce the result mod q. + // We can safely accumulate ~2**33 31bits values into a single accumulator. + // That gives us a maximum of 2**33 * 8 = 2**36 31bits values to sum safely. 
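+	// Two independent accumulators (acc1, acc2) are used so that the VPADDQ
+	// additions in the loop do not serialize on a single dependency chain.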
-TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ t+0(FP), R15 MOVQ a+8(FP), R14 MOVQ n+16(FP), CX diff --git a/field/babybear/element_amd64.s b/field/babybear/element_amd64.s index da9b9a9d73..cde280d193 100644 --- a/field/babybear/element_amd64.s +++ b/field/babybear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 4604324921634291592 +// We include the hash to force the Go compiler to recompile: 11172894854395138580 #include "../asm/element_1w_amd64.s" diff --git a/field/generator/asm/amd64/element_vec_F31.go b/field/generator/asm/amd64/element_vec_F31.go index 538b971940..174e842197 100644 --- a/field/generator/asm/amd64/element_vec_F31.go +++ b/field/generator/asm/amd64/element_vec_F31.go @@ -158,18 +158,18 @@ func (f *FFAmd64) generateSubVecF31() { // sumVec res = sum(a[0...n]) func (f *FFAmd64) generateSumVecF31() { f.Comment("sumVec(res *uint64, a *[]uint32, n uint64) res = sum(a[0...n])") + const argSize = 3 * 8 + stackSize := f.StackSize(f.NbWords*3+2, 0, 0) + registers := f.FnHeader("sumVec", stackSize, argSize, amd64.DX, amd64.AX) + defer f.AssertCleanStack(stackSize, 0) + f.WriteLn(` - // We are load 8 31bits values at a time and accumulate them into an accumulator of + // We load 8 31bits values at a time and accumulate them into an accumulator of // 8 quadwords (64bits). The caller then needs to reduce the result mod q. // We can safely accumulate ~2**33 31bits values into a single accumulator. // That gives us a maximum of 2**33 * 8 = 2**36 31bits values to sum safely. `) - const argSize = 3 * 8 - stackSize := f.StackSize(f.NbWords*3+2, 0, 0) - registers := f.FnHeader("sumVec", stackSize, argSize, amd64.DX, amd64.AX) - defer f.AssertCleanStack(stackSize, 0) - // registers & labels we need addrA := f.Pop(®isters) addrT := f.Pop(®isters) diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index ec0a7139d3..26ee5fe5e7 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -313,26 +313,6 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) F.Mu = muSlice[0] } - // We define MONTY_MU = PRIME^-1 (mod 2^MONTY_BITS). This is different from the usual convention - // (MONTY_MU = -PRIME^-1 (mod 2^MONTY_BITS)) but it avoids a carry. 
- // 2164260865 - if F.F31 { - // _mu := big.NewInt(0) - // _mu.Set(&bModulus) - // _mu.Neg(_mu) - // _mu.ModInverse(_mu, big.NewInt(1<<31)) - // muSlice := toUint64Slice(_mu, F.NbWords) - // F.Mu = muSlice[0] - _r := big.NewInt(1) - _r.Lsh(_r, uint(F.NbWords)*radix) - _rInv := big.NewInt(1) - _qInv := big.NewInt(0) - extendedEuclideanAlgo(_r, &bModulus, _rInv, _qInv) - _qInv.Neg(_qInv) - _qInv.Mod(_qInv, _r) - F.Mu = toUint64Slice(_qInv, F.NbWords)[0] - } - return F, nil } diff --git a/field/koalabear/element_amd64.s b/field/koalabear/element_amd64.s index da9b9a9d73..cde280d193 100644 --- a/field/koalabear/element_amd64.s +++ b/field/koalabear/element_amd64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 4604324921634291592 +// We include the hash to force the Go compiler to recompile: 11172894854395138580 #include "../asm/element_1w_amd64.s" From 5514b84f7e7bdd8b279b3abaac159af6f3d74ec2 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 21:48:43 +0000 Subject: [PATCH 51/74] style: more cleaning --- field/generator/asm/amd64/element_vec_F31.go | 79 +------------------- field/generator/asm/arm64/build.go | 42 ----------- 2 files changed, 2 insertions(+), 119 deletions(-) diff --git a/field/generator/asm/amd64/element_vec_F31.go b/field/generator/asm/amd64/element_vec_F31.go index 174e842197..c8bde3e005 100644 --- a/field/generator/asm/amd64/element_vec_F31.go +++ b/field/generator/asm/amd64/element_vec_F31.go @@ -449,7 +449,8 @@ func (f *FFAmd64) generateInnerProdVecF31() { f.VPADDQ(P, PL, P, "P = P + m") f.VPSRLQ("$32", P, P, "P = P >> 32") - // TODO @gbotrel comment on the bound and ensure caller can't trigger overflow in accumulator. + // we can accumulate ~2**32 32bits values into a single accumulator without overflow; + // that gives us a maximum of 2**32 * 8 = 2**35 32bits values to sum safely. f.Comment("accumulate P into acc, P is in [0, 2q] on 32bits max") f.VPADDQ(P, acc, acc, "acc += P") @@ -468,79 +469,3 @@ func (f *FFAmd64) generateInnerProdVecF31() { f.Push(®isters, addrA, addrT, len) } - -// // subVec res = a - b -// // func subVec(res, a, b *{{.ElementName}}, n uint64) -// func (f *FFAmd64) generateSubVecW4() { -// f.Comment("subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n]") - -// const argSize = 4 * 8 -// stackSize := f.StackSize(f.NbWords*2+5, 0, 0) -// registers := f.FnHeader("subVec", stackSize, argSize) -// defer f.AssertCleanStack(stackSize, 0) - -// // registers -// addrA := f.Pop(®isters) -// addrB := f.Pop(®isters) -// addrRes := f.Pop(®isters) -// len := f.Pop(®isters) -// zero := f.Pop(®isters) - -// a := f.PopN(®isters) -// q := f.PopN(®isters) - -// loop := f.NewLabel("loop") -// done := f.NewLabel("done") - -// // load arguments -// f.MOVQ("res+0(FP)", addrRes) -// f.MOVQ("a+8(FP)", addrA) -// f.MOVQ("b+16(FP)", addrB) -// f.MOVQ("n+24(FP)", len) - -// f.XORQ(zero, zero) - -// f.LABEL(loop) - -// f.TESTQ(len, len) -// f.JEQ(done, "n == 0, we are done") - -// // a = a - b -// f.LabelRegisters("a", a...) -// f.Mov(addrA, a) -// f.Sub(addrB, a) -// f.WriteLn(fmt.Sprintf("PREFETCHT0 2048(%[1]s)", addrA)) -// f.WriteLn(fmt.Sprintf("PREFETCHT0 2048(%[1]s)", addrB)) - -// // reduce a -// f.Comment("reduce (a-b) mod q") -// f.LabelRegisters("q", q...) 
-// for i := 0; i < f.NbWords; i++ { -// f.MOVQ(fmt.Sprintf("$const_q%d", i), q[i]) -// } -// for i := 0; i < f.NbWords; i++ { -// f.CMOVQCC(zero, q[i]) -// } -// // add registers (q or 0) to a, and set to result -// f.Comment("add registers (q or 0) to a, and set to result") -// f.Add(q, a) - -// // save a into res -// f.Mov(a, addrRes) - -// f.Comment("increment pointers to visit next element") -// f.ADDQ("$32", addrA) -// f.ADDQ("$32", addrB) -// f.ADDQ("$32", addrRes) -// f.DECQ(len, "decrement n") -// f.JMP(loop) - -// f.LABEL(done) - -// f.RET() - -// f.Push(®isters, a...) -// f.Push(®isters, q...) -// f.Push(®isters, addrA, addrB, addrRes, len, zero) - -// } diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 08bc52fb1c..8b976c09d7 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -116,44 +116,6 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { return nil } -// // Generate generates assembly code for the base field provided to goff -// // see internal/templates/ops* -// func Generate(w io.Writer, F *field.Field) error { -// f := NewFFArm64(w, F) -// f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) - -// f.WriteLn("#include \"textflag.h\"") -// f.WriteLn("#include \"funcdata.h\"\n") - -// f.generateStoreVector() - -// // add -// //TODO: It requires field size < 960 -// f.generateAdd() - -// // sub -// f.generateSub() - -// // double -// f.generateDouble() - -// // neg -// f.generateNeg() -// /* -// // reduce -// f.generateReduce() - -// // mul by constants -// f.generateMulBy3() -// f.generateMulBy5() -// f.generateMulBy13() - -// // fft butterflies -// f.generateButterfly()*/ - -// return nil -// } - func (f *FFArm64) DefineFn(name string) (fn defineFn, err error) { fn, ok := f.mDefines[name] if !ok { @@ -290,7 +252,3 @@ func (f *FFArm64) qAt(index int) string { func (f *FFArm64) qInv0() string { return "$const_qInvNeg" } - -func (f *FFArm64) qi(i int) string { - return fmt.Sprintf("$const_q%d", i) -} From ccb7ad1ea2f50ba07f37ca245051d6b6de7d9e65 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 21:54:05 +0000 Subject: [PATCH 52/74] refactor: give nb bits to asm generation --- field/generator/asm/amd64/build.go | 9 +++++++-- field/generator/asm/arm64/build.go | 2 +- field/generator/generator.go | 8 ++++---- field/generator/generator_test.go | 12 ++++++------ field/goff/cmd/root.go | 4 ++-- field/internal/main.go | 2 +- internal/generator/main.go | 20 ++++++++++---------- 7 files changed, 31 insertions(+), 26 deletions(-) diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 30d482236b..0af3fda97f 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -266,7 +266,7 @@ func GenerateFieldWrapper(w io.Writer, F *config.FieldConfig, asmDirBuildPath, a // GenerateCommonASM generates assembly code for the base field provided to goff // see internal/templates/ops* -func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { +func GenerateCommonASM(w io.Writer, nbWords, nbBits int, hasVector bool) error { f := NewFFAmd64(w, nbWords) f.Comment("Code generated by gnark-crypto/generator. 
DO NOT EDIT.") @@ -276,7 +276,12 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { f.WriteLn("") if nbWords == 1 { - return GenerateF31ASM(f, hasVector) + if nbBits == 31 { + return GenerateF31ASM(f, hasVector) + } else { + panic("not implemented") + } + } f.GenerateReduceDefine() diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 8b976c09d7..20519c6019 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -94,7 +94,7 @@ func GenerateFieldWrapper(w io.Writer, F *config.FieldConfig, asmDirBuildPath, a // GenerateCommonASM generates assembly code for the base field provided to goff // see internal/templates/ops* -func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { +func GenerateCommonASM(w io.Writer, nbWords, nbBits int, hasVector bool) error { f := NewFFArm64(w, nbWords) f.Comment("Code generated by gnark-crypto/generator. DO NOT EDIT.") diff --git a/field/generator/generator.go b/field/generator/generator.go index 174c205c27..a3b5513949 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -243,7 +243,7 @@ func shorten(input string) string { return input } -func GenerateARM64(nbWords int, asmDir string, hasVector bool) error { +func GenerateARM64(nbWords, nbBits int, asmDir string, hasVector bool) error { os.MkdirAll(asmDir, 0755) pathSrc := filepath.Join(asmDir, fmt.Sprintf(arm64.ElementASMFileName, nbWords)) @@ -253,7 +253,7 @@ func GenerateARM64(nbWords int, asmDir string, hasVector bool) error { return err } - if err := arm64.GenerateCommonASM(f, nbWords, hasVector); err != nil { + if err := arm64.GenerateCommonASM(f, nbWords, nbBits, hasVector); err != nil { _ = f.Close() return err } @@ -271,7 +271,7 @@ func GenerateARM64(nbWords int, asmDir string, hasVector bool) error { return nil } -func GenerateAMD64(nbWords int, asmDir string, hasVector bool) error { +func GenerateAMD64(nbWords, nbBits int, asmDir string, hasVector bool) error { os.MkdirAll(asmDir, 0755) pathSrc := filepath.Join(asmDir, fmt.Sprintf(amd64.ElementASMFileName, nbWords)) @@ -281,7 +281,7 @@ func GenerateAMD64(nbWords int, asmDir string, hasVector bool) error { return err } - if err := amd64.GenerateCommonASM(f, nbWords, hasVector); err != nil { + if err := amd64.GenerateCommonASM(f, nbWords, nbBits, hasVector); err != nil { _ = f.Close() return err } diff --git a/field/generator/generator_test.go b/field/generator/generator_test.go index ee5b5fafc0..91a896b0b4 100644 --- a/field/generator/generator_test.go +++ b/field/generator/generator_test.go @@ -79,13 +79,13 @@ func TestIntegration(t *testing.T) { moduli["e_nocarry_edge_0127"] = "170141183460469231731687303715884105727" moduli["e_nocarry_edge_1279"] = "10407932194664399081925240327364085538615262247266704805319112350403608059673360298012239441732324184842421613954281007791383566248323464908139906605677320762924129509389220345773183349661583550472959420547689811211693677147548478866962501384438260291732348885311160828538416585028255604666224831890918801847068222203140521026698435488732958028878050869736186900714720710555703168729087" - assert.NoError(GenerateAMD64(2, asmDir, false)) - assert.NoError(GenerateAMD64(3, asmDir, false)) - assert.NoError(GenerateAMD64(7, asmDir, false)) - assert.NoError(GenerateAMD64(8, asmDir, false)) + assert.NoError(GenerateAMD64(2, 0, asmDir, false)) + assert.NoError(GenerateAMD64(3, 0, asmDir, false)) + assert.NoError(GenerateAMD64(7, 0, asmDir, false)) + 
assert.NoError(GenerateAMD64(8, 0, asmDir, false)) - assert.NoError(GenerateARM64(2, asmDir, false)) - assert.NoError(GenerateARM64(8, asmDir, false)) + assert.NoError(GenerateARM64(2, 0, asmDir, false)) + assert.NoError(GenerateARM64(8, 0, asmDir, false)) for elementName, modulus := range moduli { var fIntegration *field.FieldConfig diff --git a/field/goff/cmd/root.go b/field/goff/cmd/root.go index c2ae86809e..b70a048932 100644 --- a/field/goff/cmd/root.go +++ b/field/goff/cmd/root.go @@ -74,13 +74,13 @@ func cmdGenerate(cmd *cobra.Command, args []string) { asmDir := filepath.Join(fOutputDir, "asm") if F.GenerateOpsAMD64 { - if err := generator.GenerateAMD64(F.NbWords, asmDir, F.GenerateVectorOpsAMD64); err != nil { + if err := generator.GenerateAMD64(F.NbWords, F.NbBits, asmDir, F.GenerateVectorOpsAMD64); err != nil { fmt.Printf("\n%s\n", err.Error()) os.Exit(-1) } } if F.GenerateOpsARM64 { - if err := generator.GenerateARM64(F.NbWords, asmDir, F.GenerateVectorOpsARM64); err != nil { + if err := generator.GenerateARM64(F.NbWords, F.NbBits, asmDir, F.GenerateVectorOpsARM64); err != nil { fmt.Printf("\n%s\n", err.Error()) os.Exit(-1) } diff --git a/field/internal/main.go b/field/internal/main.go index ce896f7fd3..29d84091c1 100644 --- a/field/internal/main.go +++ b/field/internal/main.go @@ -26,7 +26,7 @@ func main() { // generate assembly asmDir := filepath.Join("..", "asm") asmDirIncludePath := filepath.Join("..", "asm") - if err := generator.GenerateAMD64(1, asmDir, true); err != nil { + if err := generator.GenerateAMD64(1, 31, asmDir, true); err != nil { panic(err) } diff --git a/internal/generator/main.go b/internal/generator/main.go index c17605a414..dadbcab04f 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -52,16 +52,16 @@ func main() { asmDirIncludePath := filepath.Join(baseDir, "..", "field", "asm") // generate common assembly files depending on field number of words - assertNoError(generator.GenerateAMD64(4, asmDirBuildPath, true)) - assertNoError(generator.GenerateAMD64(5, asmDirBuildPath, false)) - assertNoError(generator.GenerateAMD64(6, asmDirBuildPath, false)) - assertNoError(generator.GenerateAMD64(10, asmDirBuildPath, false)) - assertNoError(generator.GenerateAMD64(12, asmDirBuildPath, false)) - - assertNoError(generator.GenerateARM64(4, asmDirBuildPath, false)) - assertNoError(generator.GenerateARM64(6, asmDirBuildPath, false)) - assertNoError(generator.GenerateARM64(10, asmDirBuildPath, false)) - assertNoError(generator.GenerateARM64(12, asmDirBuildPath, false)) + assertNoError(generator.GenerateAMD64(4, 0, asmDirBuildPath, true)) + assertNoError(generator.GenerateAMD64(5, 0, asmDirBuildPath, false)) + assertNoError(generator.GenerateAMD64(6, 0, asmDirBuildPath, false)) + assertNoError(generator.GenerateAMD64(10, 0, asmDirBuildPath, false)) + assertNoError(generator.GenerateAMD64(12, 0, asmDirBuildPath, false)) + + assertNoError(generator.GenerateARM64(4, 0, asmDirBuildPath, false)) + assertNoError(generator.GenerateARM64(6, 0, asmDirBuildPath, false)) + assertNoError(generator.GenerateARM64(10, 0, asmDirBuildPath, false)) + assertNoError(generator.GenerateARM64(12, 0, asmDirBuildPath, false)) var wg sync.WaitGroup for _, conf := range config.Curves { From 9651c066e73b526e1e0ed07c2051af09393b8dae Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 22:02:48 +0000 Subject: [PATCH 53/74] refactor: distinguish nb of bits in file name for generated assembly --- field/asm/element_1w_amd64.s | 222 ----------------------------- 
field/babybear/element_amd64.s | 2 +- field/generator/asm/amd64/build.go | 56 ++------ field/generator/asm/arm64/build.go | 57 ++------ field/generator/generator.go | 11 +- field/koalabear/element_amd64.s | 2 +- 6 files changed, 25 insertions(+), 325 deletions(-) delete mode 100644 field/asm/element_1w_amd64.s diff --git a/field/asm/element_1w_amd64.s b/field/asm/element_1w_amd64.s deleted file mode 100644 index f8f134291f..0000000000 --- a/field/asm/element_1w_amd64.s +++ /dev/null @@ -1,222 +0,0 @@ -// Code generated by gnark-crypto/generator. DO NOT EDIT. -#include "textflag.h" -#include "funcdata.h" -#include "go_asm.h" - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVD $const_q, AX - VPBROADCASTD AX, Z3 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - VMOVDQU32 0(AX), Z0 - VMOVDQU32 0(DX), Z1 - VPADDD Z0, Z1, Z0 // a = a + b - VPSUBD Z3, Z0, Z2 // t = a - q - VPMINUD Z0, Z2, Z1 // b = min(t, a) - VMOVDQU32 Z1, 0(CX) // res = b - - // increment pointers to visit next element - ADDQ $64, AX - ADDQ $64, DX - ADDQ $64, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVD $const_q, AX - VPBROADCASTD AX, Z3 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - VMOVDQU32 0(AX), Z0 - VMOVDQU32 0(DX), Z1 - VPSUBD Z1, Z0, Z0 // a = a - b - VPADDD Z3, Z0, Z2 // t = a + q - VPMINUD Z0, Z2, Z1 // b = min(t, a) - VMOVDQU32 Z1, 0(CX) // res = b - - // increment pointers to visit next element - ADDQ $64, AX - ADDQ $64, DX - ADDQ $64, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// sumVec(res *uint64, a *[]uint32, n uint64) res = sum(a[0...n]) -TEXT ·sumVec(SB), NOSPLIT, $0-24 - - // We load 8 31bits values at a time and accumulate them into an accumulator of - // 8 quadwords (64bits). The caller then needs to reduce the result mod q. - // We can safely accumulate ~2**33 31bits values into a single accumulator. - // That gives us a maximum of 2**33 * 8 = 2**36 31bits values to sum safely. 
- - MOVQ t+0(FP), R15 - MOVQ a+8(FP), R14 - MOVQ n+16(FP), CX - VXORPS Z2, Z2, Z2 // acc1 = 0 - VMOVDQA64 Z2, Z3 // acc2 = 0 - -loop_5: - TESTQ CX, CX - JEQ done_6 // n == 0, we are done - VPMOVZXDQ 0(R14), Z0 // load 8 31bits values in a1 - VPMOVZXDQ 32(R14), Z1 // load 8 31bits values in a2 - VPADDQ Z0, Z2, Z2 // acc1 += a1 - VPADDQ Z1, Z3, Z3 // acc2 += a2 - - // increment pointers to visit next element - ADDQ $64, R14 - DECQ CX // decrement n - JMP loop_5 - -done_6: - VPADDQ Z2, Z3, Z2 // acc1 += acc2 - VMOVDQU64 Z2, 0(R15) // res = acc1 - RET - -// mulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b[0...n] -TEXT ·mulVec(SB), NOSPLIT, $0-32 - MOVD $const_q, AX - VPBROADCASTQ AX, Z3 - MOVD $const_qInvNeg, AX - VPBROADCASTQ AX, Z4 - - // Create mask for low dword in each qword - VPCMPEQB Y0, Y0, Y0 - VPMOVZXDQ Y0, Z6 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_7: - TESTQ BX, BX - JEQ done_8 // n == 0, we are done - VPMOVZXDQ 0(AX), Z0 - VPMOVZXDQ 0(DX), Z1 - VPMULUDQ Z0, Z1, Z2 // P = a * b - VPANDQ Z6, Z2, Z5 // m = uint32(P) - VPMULUDQ Z5, Z4, Z5 // m = m * qInvNeg - VPANDQ Z6, Z5, Z5 // m = uint32(m) - VPMULUDQ Z5, Z3, Z5 // m = m * q - VPADDQ Z2, Z5, Z2 // P = P + m - VPSRLQ $32, Z2, Z2 // P = P >> 32 - VPSUBD Z3, Z2, Z5 // PL = P - q - VPMINUD Z2, Z5, Z2 // P = min(P, PL) - VPMOVQD Z2, 0(CX) // res = P - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_7 - -done_8: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), NOSPLIT, $0-32 - MOVD $const_q, AX - VPBROADCASTQ AX, Z3 - MOVD $const_qInvNeg, AX - VPBROADCASTQ AX, Z4 - - // Create mask for low dword in each qword - VPCMPEQB Y0, Y0, Y0 - VPMOVZXDQ Y0, Z6 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - VPBROADCASTD 0(DX), Z1 - -loop_9: - TESTQ BX, BX - JEQ done_10 // n == 0, we are done - VPMOVZXDQ 0(AX), Z0 - VPMULUDQ Z0, Z1, Z2 // P = a * b - VPANDQ Z6, Z2, Z5 // m = uint32(P) - VPMULUDQ Z5, Z4, Z5 // m = m * qInvNeg - VPANDQ Z6, Z5, Z5 // m = uint32(m) - VPMULUDQ Z5, Z3, Z5 // m = m * q - VPADDQ Z2, Z5, Z2 // P = P + m - VPSRLQ $32, Z2, Z2 // P = P >> 32 - VPSUBD Z3, Z2, Z5 // PL = P - q - VPMINUD Z2, Z5, Z2 // P = min(P, PL) - VPMOVQD Z2, 0(CX) // res = P - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_9 - -done_10: - RET - -// innerProdVec(t *uint64, a,b *[]uint32, n uint64) res = sum(a[0...n] * b[0...n]) -TEXT ·innerProdVec(SB), NOSPLIT, $0-32 - - // Similar to mulVec; we do most of the montgomery multiplication but don't do - // the final reduction. We accumulate the result like in sumVec and let the caller - // reduce mod q. 
- - MOVD $const_q, AX - VPBROADCASTQ AX, Z3 - MOVD $const_qInvNeg, AX - VPBROADCASTQ AX, Z4 - - // Create mask for low dword in each qword - VPCMPEQB Y0, Y0, Y0 - VPMOVZXDQ Y0, Z6 - VXORPS Z2, Z2, Z2 // acc = 0 - MOVQ t+0(FP), CX - MOVQ a+8(FP), R14 - MOVQ b+16(FP), R15 - MOVQ n+24(FP), BX - -loop_11: - TESTQ BX, BX - JEQ done_12 // n == 0, we are done - VPMOVZXDQ 0(R14), Z0 - VPMOVZXDQ 0(R15), Z1 - VPMULUDQ Z0, Z1, Z7 // P = a * b - VPANDQ Z6, Z7, Z5 // m = uint32(P) - VPMULUDQ Z5, Z4, Z5 // m = m * qInvNeg - VPANDQ Z6, Z5, Z5 // m = uint32(m) - VPMULUDQ Z5, Z3, Z5 // m = m * q - VPADDQ Z7, Z5, Z7 // P = P + m - VPSRLQ $32, Z7, Z7 // P = P >> 32 - - // accumulate P into acc, P is in [0, 2q] on 32bits max - VPADDQ Z7, Z2, Z2 // acc += P - - // increment pointers to visit next element - ADDQ $32, R14 - ADDQ $32, R15 - DECQ BX // decrement n - JMP loop_11 - -done_12: - VMOVDQU64 Z2, 0(CX) // res = acc - RET diff --git a/field/babybear/element_amd64.s b/field/babybear/element_amd64.s index cde280d193..80392ac8f5 100644 --- a/field/babybear/element_amd64.s +++ b/field/babybear/element_amd64.s @@ -17,5 +17,5 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT // We include the hash to force the Go compiler to recompile: 11172894854395138580 -#include "../asm/element_1w_amd64.s" +#include "../asm/element_31b_amd64.s" diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 0af3fda97f..aeb1cd7a7e 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -17,20 +17,13 @@ package amd64 import ( "fmt" - "hash/fnv" "io" - "os" - "path/filepath" "strings" "github.com/consensys/bavard/amd64" - "github.com/consensys/gnark-crypto/field/generator/config" ) const SmallModulus = 6 -const ( - ElementASMFileName = "element_%dw_amd64.s" -) func NewFFAmd64(w io.Writer, nbWords int) *FFAmd64 { F := &FFAmd64{ @@ -224,46 +217,6 @@ func (f *FFAmd64) mu() string { return "$const_mu" } -func GenerateFieldWrapper(w io.Writer, F *config.FieldConfig, asmDirBuildPath, asmDirIncludePath string) error { - // for each field we generate the defines for the modulus and the montgomery constant - f := NewFFAmd64(w, F.NbWords) - - // we add the defines first, then the common asm, then the global variable section - // to enable correct compilations with #include in order. - f.WriteLn("") - - hashAndInclude := func(fileName string) error { - // we hash the file content and include the hash in comment of the generated file - // to force the Go compiler to recompile the file if the content has changed - fData, err := os.ReadFile(filepath.Join(asmDirBuildPath, fileName)) - if err != nil { - return err - } - // hash the file using FNV - hasher := fnv.New64() - hasher.Write(fData) - hash := hasher.Sum64() - - f.WriteLn("// Code generated by gnark-crypto/generator. 
DO NOT EDIT.") - f.WriteLn(fmt.Sprintf("// We include the hash to force the Go compiler to recompile: %d", hash)) - includePath := filepath.Join(asmDirIncludePath, fileName) - // on windows, we replace the "\" by "/" - if filepath.Separator == '\\' { - includePath = strings.ReplaceAll(includePath, "\\", "/") - } - f.WriteLn(fmt.Sprintf("#include \"%s\"\n", includePath)) - - return nil - } - - toInclude := fmt.Sprintf(ElementASMFileName, F.NbWords) - if err := hashAndInclude(toInclude); err != nil { - return err - } - - return nil -} - // GenerateCommonASM generates assembly code for the base field provided to goff // see internal/templates/ops* func GenerateCommonASM(w io.Writer, nbWords, nbBits int, hasVector bool) error { @@ -333,3 +286,12 @@ func GenerateF31ASM(f *FFAmd64, hasVector bool) error { return nil } + +func ElementASMFileName(nbWords, nbBits int) string { + const nameW1 = "element_%db_amd64.s" + const nameWN = "element_%dw_amd64.s" + if nbWords == 1 { + return fmt.Sprintf(nameW1, nbBits) + } + return fmt.Sprintf(nameWN, nbWords) +} diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 20519c6019..ed872a7940 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -2,18 +2,10 @@ package arm64 import ( "fmt" - "hash/fnv" "io" - "os" - "path/filepath" "strings" "github.com/consensys/bavard/arm64" - "github.com/consensys/gnark-crypto/field/generator/config" -) - -const ( - ElementASMFileName = "element_%dw_arm64.s" ) type defineFn func(args ...arm64.Register) @@ -52,46 +44,6 @@ type FFArm64 struct { mDefines map[string]defineFn } -func GenerateFieldWrapper(w io.Writer, F *config.FieldConfig, asmDirBuildPath, asmDirIncludePath string) error { - // for each field we generate the defines for the modulus and the montgomery constant - f := NewFFArm64(w, F.NbWords) - - // we add the defines first, then the common asm, then the global variable section - // to enable correct compilations with #include in order. - f.WriteLn("") - - hashAndInclude := func(fileName string) error { - // we hash the file content and include the hash in comment of the generated file - // to force the Go compiler to recompile the file if the content has changed - fData, err := os.ReadFile(filepath.Join(asmDirBuildPath, fileName)) - if err != nil { - return err - } - // hash the file using FNV - hasher := fnv.New64() - hasher.Write(fData) - hash := hasher.Sum64() - - f.WriteLn("// Code generated by gnark-crypto/generator. 
DO NOT EDIT.") - f.WriteLn(fmt.Sprintf("// We include the hash to force the Go compiler to recompile: %d", hash)) - includePath := filepath.Join(asmDirIncludePath, fileName) - // on windows, we replace the "\" by "/" - if filepath.Separator == '\\' { - includePath = strings.ReplaceAll(includePath, "\\", "/") - } - f.WriteLn(fmt.Sprintf("#include \"%s\"\n", includePath)) - - return nil - } - - toInclude := fmt.Sprintf(ElementASMFileName, F.NbWords) - if err := hashAndInclude(toInclude); err != nil { - return err - } - - return nil -} - // GenerateCommonASM generates assembly code for the base field provided to goff // see internal/templates/ops* func GenerateCommonASM(w io.Writer, nbWords, nbBits int, hasVector bool) error { @@ -252,3 +204,12 @@ func (f *FFArm64) qAt(index int) string { func (f *FFArm64) qInv0() string { return "$const_qInvNeg" } + +func ElementASMFileName(nbWords, nbBits int) string { + const nameW1 = "element_%db_arm64.s" + const nameWN = "element_%dw_arm64.s" + if nbWords == 1 { + return fmt.Sprintf(nameW1, nbBits) + } + return fmt.Sprintf(nameWN, nbWords) +} diff --git a/field/generator/generator.go b/field/generator/generator.go index a3b5513949..7cb08b478a 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -112,14 +112,14 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude var err error if F.GenerateOpsAMD64 { - amd64d, err = hashAndInclude(asmDirBuildPath, asmDirIncludePath, amd64.ElementASMFileName, F.NbWords) + amd64d, err = hashAndInclude(asmDirBuildPath, asmDirIncludePath, amd64.ElementASMFileName(F.NbWords, F.NbBits)) if err != nil { return err } } if F.GenerateOpsARM64 { - arm64d, err = hashAndInclude(asmDirBuildPath, asmDirIncludePath, arm64.ElementASMFileName, F.NbWords) + arm64d, err = hashAndInclude(asmDirBuildPath, asmDirIncludePath, arm64.ElementASMFileName(F.NbWords, F.NbBits)) if err != nil { return err } @@ -208,8 +208,7 @@ type ASMWrapperData struct { Hash string } -func hashAndInclude(asmDirBuildPath, asmDirIncludePath, fileName string, nbWords int) (data ASMWrapperData, err error) { - fileName = fmt.Sprintf(fileName, nbWords) +func hashAndInclude(asmDirBuildPath, asmDirIncludePath, fileName string) (data ASMWrapperData, err error) { // we hash the file content and include the hash in comment of the generated file // to force the Go compiler to recompile the file if the content has changed fData, err := os.ReadFile(filepath.Join(asmDirBuildPath, fileName)) @@ -245,7 +244,7 @@ func shorten(input string) string { func GenerateARM64(nbWords, nbBits int, asmDir string, hasVector bool) error { os.MkdirAll(asmDir, 0755) - pathSrc := filepath.Join(asmDir, fmt.Sprintf(arm64.ElementASMFileName, nbWords)) + pathSrc := filepath.Join(asmDir, arm64.ElementASMFileName(nbWords, nbBits)) fmt.Println("generating", pathSrc) f, err := os.Create(pathSrc) @@ -273,7 +272,7 @@ func GenerateARM64(nbWords, nbBits int, asmDir string, hasVector bool) error { func GenerateAMD64(nbWords, nbBits int, asmDir string, hasVector bool) error { os.MkdirAll(asmDir, 0755) - pathSrc := filepath.Join(asmDir, fmt.Sprintf(amd64.ElementASMFileName, nbWords)) + pathSrc := filepath.Join(asmDir, amd64.ElementASMFileName(nbWords, nbBits)) fmt.Println("generating", pathSrc) f, err := os.Create(pathSrc) diff --git a/field/koalabear/element_amd64.s b/field/koalabear/element_amd64.s index cde280d193..80392ac8f5 100644 --- a/field/koalabear/element_amd64.s +++ b/field/koalabear/element_amd64.s @@ -17,5 +17,5 @@ // Code generated by 
consensys/gnark-crypto DO NOT EDIT // We include the hash to force the Go compiler to recompile: 11172894854395138580 -#include "../asm/element_1w_amd64.s" +#include "../asm/element_31b_amd64.s" From 31b74f227d6e24bfbd97c612a940d72ddcc71ac6 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 8 Dec 2024 22:03:00 +0000 Subject: [PATCH 54/74] feat: add missing file --- field/asm/element_31b_amd64.s | 222 ++++++++++++++++++++++++++++++++++ 1 file changed, 222 insertions(+) create mode 100644 field/asm/element_31b_amd64.s diff --git a/field/asm/element_31b_amd64.s b/field/asm/element_31b_amd64.s new file mode 100644 index 0000000000..f8f134291f --- /dev/null +++ b/field/asm/element_31b_amd64.s @@ -0,0 +1,222 @@ +// Code generated by gnark-crypto/generator. DO NOT EDIT. +#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVD $const_q, AX + VPBROADCASTD AX, Z3 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + VMOVDQU32 0(AX), Z0 + VMOVDQU32 0(DX), Z1 + VPADDD Z0, Z1, Z0 // a = a + b + VPSUBD Z3, Z0, Z2 // t = a - q + VPMINUD Z0, Z2, Z1 // b = min(t, a) + VMOVDQU32 Z1, 0(CX) // res = b + + // increment pointers to visit next element + ADDQ $64, AX + ADDQ $64, DX + ADDQ $64, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVD $const_q, AX + VPBROADCASTD AX, Z3 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + VMOVDQU32 0(AX), Z0 + VMOVDQU32 0(DX), Z1 + VPSUBD Z1, Z0, Z0 // a = a - b + VPADDD Z3, Z0, Z2 // t = a + q + VPMINUD Z0, Z2, Z1 // b = min(t, a) + VMOVDQU32 Z1, 0(CX) // res = b + + // increment pointers to visit next element + ADDQ $64, AX + ADDQ $64, DX + ADDQ $64, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// sumVec(res *uint64, a *[]uint32, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + + // We load 8 31bits values at a time and accumulate them into an accumulator of + // 8 quadwords (64bits). The caller then needs to reduce the result mod q. + // We can safely accumulate ~2**33 31bits values into a single accumulator. + // That gives us a maximum of 2**33 * 8 = 2**36 31bits values to sum safely. 
+ + MOVQ t+0(FP), R15 + MOVQ a+8(FP), R14 + MOVQ n+16(FP), CX + VXORPS Z2, Z2, Z2 // acc1 = 0 + VMOVDQA64 Z2, Z3 // acc2 = 0 + +loop_5: + TESTQ CX, CX + JEQ done_6 // n == 0, we are done + VPMOVZXDQ 0(R14), Z0 // load 8 31bits values in a1 + VPMOVZXDQ 32(R14), Z1 // load 8 31bits values in a2 + VPADDQ Z0, Z2, Z2 // acc1 += a1 + VPADDQ Z1, Z3, Z3 // acc2 += a2 + + // increment pointers to visit next element + ADDQ $64, R14 + DECQ CX // decrement n + JMP loop_5 + +done_6: + VPADDQ Z2, Z3, Z2 // acc1 += acc2 + VMOVDQU64 Z2, 0(R15) // res = acc1 + RET + +// mulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b[0...n] +TEXT ·mulVec(SB), NOSPLIT, $0-32 + MOVD $const_q, AX + VPBROADCASTQ AX, Z3 + MOVD $const_qInvNeg, AX + VPBROADCASTQ AX, Z4 + + // Create mask for low dword in each qword + VPCMPEQB Y0, Y0, Y0 + VPMOVZXDQ Y0, Z6 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_7: + TESTQ BX, BX + JEQ done_8 // n == 0, we are done + VPMOVZXDQ 0(AX), Z0 + VPMOVZXDQ 0(DX), Z1 + VPMULUDQ Z0, Z1, Z2 // P = a * b + VPANDQ Z6, Z2, Z5 // m = uint32(P) + VPMULUDQ Z5, Z4, Z5 // m = m * qInvNeg + VPANDQ Z6, Z5, Z5 // m = uint32(m) + VPMULUDQ Z5, Z3, Z5 // m = m * q + VPADDQ Z2, Z5, Z2 // P = P + m + VPSRLQ $32, Z2, Z2 // P = P >> 32 + VPSUBD Z3, Z2, Z5 // PL = P - q + VPMINUD Z2, Z5, Z2 // P = min(P, PL) + VPMOVQD Z2, 0(CX) // res = P + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_7 + +done_8: + RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), NOSPLIT, $0-32 + MOVD $const_q, AX + VPBROADCASTQ AX, Z3 + MOVD $const_qInvNeg, AX + VPBROADCASTQ AX, Z4 + + // Create mask for low dword in each qword + VPCMPEQB Y0, Y0, Y0 + VPMOVZXDQ Y0, Z6 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + VPBROADCASTD 0(DX), Z1 + +loop_9: + TESTQ BX, BX + JEQ done_10 // n == 0, we are done + VPMOVZXDQ 0(AX), Z0 + VPMULUDQ Z0, Z1, Z2 // P = a * b + VPANDQ Z6, Z2, Z5 // m = uint32(P) + VPMULUDQ Z5, Z4, Z5 // m = m * qInvNeg + VPANDQ Z6, Z5, Z5 // m = uint32(m) + VPMULUDQ Z5, Z3, Z5 // m = m * q + VPADDQ Z2, Z5, Z2 // P = P + m + VPSRLQ $32, Z2, Z2 // P = P >> 32 + VPSUBD Z3, Z2, Z5 // PL = P - q + VPMINUD Z2, Z5, Z2 // P = min(P, PL) + VPMOVQD Z2, 0(CX) // res = P + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_9 + +done_10: + RET + +// innerProdVec(t *uint64, a,b *[]uint32, n uint64) res = sum(a[0...n] * b[0...n]) +TEXT ·innerProdVec(SB), NOSPLIT, $0-32 + + // Similar to mulVec; we do most of the montgomery multiplication but don't do + // the final reduction. We accumulate the result like in sumVec and let the caller + // reduce mod q. 
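+	// Operands are in Montgomery form (aR, bR); a single reduction per product
+	// yields a*b*R, so the accumulated lane sums stay in Montgomery form.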
+ + MOVD $const_q, AX + VPBROADCASTQ AX, Z3 + MOVD $const_qInvNeg, AX + VPBROADCASTQ AX, Z4 + + // Create mask for low dword in each qword + VPCMPEQB Y0, Y0, Y0 + VPMOVZXDQ Y0, Z6 + VXORPS Z2, Z2, Z2 // acc = 0 + MOVQ t+0(FP), CX + MOVQ a+8(FP), R14 + MOVQ b+16(FP), R15 + MOVQ n+24(FP), BX + +loop_11: + TESTQ BX, BX + JEQ done_12 // n == 0, we are done + VPMOVZXDQ 0(R14), Z0 + VPMOVZXDQ 0(R15), Z1 + VPMULUDQ Z0, Z1, Z7 // P = a * b + VPANDQ Z6, Z7, Z5 // m = uint32(P) + VPMULUDQ Z5, Z4, Z5 // m = m * qInvNeg + VPANDQ Z6, Z5, Z5 // m = uint32(m) + VPMULUDQ Z5, Z3, Z5 // m = m * q + VPADDQ Z7, Z5, Z7 // P = P + m + VPSRLQ $32, Z7, Z7 // P = P >> 32 + + // accumulate P into acc, P is in [0, 2q] on 32bits max + VPADDQ Z7, Z2, Z2 // acc += P + + // increment pointers to visit next element + ADDQ $32, R14 + ADDQ $32, R15 + DECQ BX // decrement n + JMP loop_11 + +done_12: + VMOVDQU64 Z2, 0(CX) // res = acc + RET From 0f549b5d5ebe2d4502b3e8a3277a8d3d2e2c9b73 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 9 Dec 2024 02:41:19 +0000 Subject: [PATCH 55/74] test: fix broken integration test --- field/generator/asm/amd64/build.go | 5 ++++- field/generator/asm/arm64/build.go | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index aeb1cd7a7e..de4da5e23f 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -291,7 +291,10 @@ func ElementASMFileName(nbWords, nbBits int) string { const nameW1 = "element_%db_amd64.s" const nameWN = "element_%dw_amd64.s" if nbWords == 1 { - return fmt.Sprintf(nameW1, nbBits) + if nbBits >= 32 { + panic("not implemented") + } + return fmt.Sprintf(nameW1, 31) } return fmt.Sprintf(nameWN, nbWords) } diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index ed872a7940..c0d617410a 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -209,7 +209,10 @@ func ElementASMFileName(nbWords, nbBits int) string { const nameW1 = "element_%db_arm64.s" const nameWN = "element_%dw_arm64.s" if nbWords == 1 { - return fmt.Sprintf(nameW1, nbBits) + if nbBits >= 32 { + panic("not implemented") + } + return fmt.Sprintf(nameW1, 31) } return fmt.Sprintf(nameWN, nbWords) } From 34fdd6e01b62e6f84a62ce2ca281ff516cf01e73 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 10 Dec 2024 08:45:15 -0600 Subject: [PATCH 56/74] chore: run go mod tidy --- go.sum | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/go.sum b/go.sum index d481d5117f..3aa78d3bb5 100644 --- a/go.sum +++ b/go.sum @@ -55,18 +55,6 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3 h1:8gPxbjhwhxXTakOXII32eLlAFLlYImoENa3uQ6iP+go= -github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= -github.com/consensys/bavard v0.1.23-0.20241207235124-babad3045f79 h1:lhIivWq5SgulQUNtgUugSMqcIpQNZkB5EPD/CwF3r9w= -github.com/consensys/bavard v0.1.23-0.20241207235124-babad3045f79/go.mod 
h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= -github.com/consensys/bavard v0.1.23-0.20241207235803-84aa6b3d4724 h1:wBPDHYgf1QvlnW/7gVZVBYVgkKjYV1J8Hbsa5qwvESs= -github.com/consensys/bavard v0.1.23-0.20241207235803-84aa6b3d4724/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= -github.com/consensys/bavard v0.1.23-0.20241208000453-1b3c9246dcd6 h1:dm/VT++/p4tq8FLR/8z361AvWPD9dcp6xXebPLaEdZo= -github.com/consensys/bavard v0.1.23-0.20241208000453-1b3c9246dcd6/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= -github.com/consensys/bavard v0.1.23-0.20241208033340-073352297c17 h1:7fr9/A1Nm0L67XO33mdSTPD+2prj7VmlQ63YE9ujHW8= -github.com/consensys/bavard v0.1.23-0.20241208033340-073352297c17/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= -github.com/consensys/bavard v0.1.23-0.20241208043834-8013eddc8088 h1:5fIHEbNpqWy7NWRzI8IInFCVSGyhH2BF4wpmY2XXg1k= -github.com/consensys/bavard v0.1.23-0.20241208043834-8013eddc8088/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/bavard v0.1.23-0.20241208050306-51bece05ad82 h1:S2g4cxvQWO4KXQygLFtoyx4Q3lqoKRk/zJFFOYtJiJw= github.com/consensys/bavard v0.1.23-0.20241208050306-51bece05ad82/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= From 21b9b80d07388d3d43f8f99be427fb54018f4fc6 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 10 Dec 2024 14:47:37 -0600 Subject: [PATCH 57/74] chore: re run go generate to update doc --- field/babybear/doc.go | 2 +- field/koalabear/doc.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/field/babybear/doc.go b/field/babybear/doc.go index 6a59371928..e134942b43 100644 --- a/field/babybear/doc.go +++ b/field/babybear/doc.go @@ -7,7 +7,7 @@ // // The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x). // -// Additionally babybear.Vector offers an API to manipulate []Element. +// Additionally babybear.Vector offers an API to manipulate []Element using AVX512 instructions if available. // // The modulus is hardcoded in all the operations. // diff --git a/field/koalabear/doc.go b/field/koalabear/doc.go index 4abddfb000..60c2390f6a 100644 --- a/field/koalabear/doc.go +++ b/field/koalabear/doc.go @@ -7,7 +7,7 @@ // // The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x). // -// Additionally koalabear.Vector offers an API to manipulate []Element. +// Additionally koalabear.Vector offers an API to manipulate []Element using AVX512 instructions if available. // // The modulus is hardcoded in all the operations. 
//

From 38324d395a671576b1ce29a63d844b3063a75ea7 Mon Sep 17 00:00:00 2001
From: Gautam Botrel
Date: Tue, 10 Dec 2024 15:03:36 -0600
Subject: [PATCH 58/74] feat: prepare skeleton for NEON on F31

---
 field/asm/element_31b_arm64.s                |  5 +++
 field/babybear/doc.go                        |  2 +-
 field/babybear/element_arm64.s               | 10 +++++
 field/babybear/vector_arm64.go               | 45 +++++++++++++++++++
 field/babybear/vector_purego.go              |  2 +-
 field/generator/asm/arm64/build.go           | 24 ++++++++++
 field/generator/config/field_config.go       |  4 +-
 field/generator/generator.go                 |  6 +--
 .../templates/element/vector_ops_asm.go      |  2 +
 field/internal/main.go                       |  5 +++
 field/koalabear/doc.go                       |  2 +-
 field/koalabear/element_arm64.s              | 10 +++++
 field/koalabear/vector_arm64.go              | 45 +++++++++++++++++++
 field/koalabear/vector_purego.go             |  2 +-
 14 files changed, 155 insertions(+), 9 deletions(-)
 create mode 100644 field/asm/element_31b_arm64.s
 create mode 100644 field/babybear/element_arm64.s
 create mode 100644 field/babybear/vector_arm64.go
 create mode 100644 field/koalabear/element_arm64.s
 create mode 100644 field/koalabear/vector_arm64.go

diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s
new file mode 100644
index 0000000000..f01ac64962
--- /dev/null
+++ b/field/asm/element_31b_arm64.s
@@ -0,0 +1,5 @@
+// Code generated by gnark-crypto/generator. DO NOT EDIT.
+#include "textflag.h"
+#include "funcdata.h"
+#include "go_asm.h"
+
diff --git a/field/babybear/doc.go b/field/babybear/doc.go
index e134942b43..fc116bd84f 100644
--- a/field/babybear/doc.go
+++ b/field/babybear/doc.go
@@ -7,7 +7,7 @@
 //
 // The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x).
 //
-// Additionally babybear.Vector offers an API to manipulate []Element using AVX512 instructions if available.
+// Additionally babybear.Vector offers an API to manipulate []Element using AVX512/NEON instructions if available.
 //
 // The modulus is hardcoded in all the operations.
 //
diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s
new file mode 100644
index 0000000000..2b0c2ab64c
--- /dev/null
+++ b/field/babybear/element_arm64.s
@@ -0,0 +1,10 @@
+//go:build !purego
+
+// Copyright 2020-2024 Consensys Software Inc.
+// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+// We include the hash to force the Go compiler to recompile: 526586546912891733
+#include "../asm/element_31b_arm64.s"
+
diff --git a/field/babybear/vector_arm64.go b/field/babybear/vector_arm64.go
new file mode 100644
index 0000000000..0f91a0fc74
--- /dev/null
+++ b/field/babybear/vector_arm64.go
@@ -0,0 +1,45 @@
+//go:build !purego
+
+// Copyright 2020-2024 Consensys Software Inc.
+// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.
+
+// Code generated by consensys/gnark-crypto DO NOT EDIT
+
+package babybear
+
+// Add adds two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Add(a, b Vector) {
+	addVecGeneric(*vector, a, b)
+}
+
+// Sub subtracts two vectors element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) Sub(a, b Vector) {
+	subVecGeneric(*vector, a, b)
+}
+
+// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self.
+// It panics if the vectors don't have the same length.
+func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/field/babybear/vector_purego.go b/field/babybear/vector_purego.go index 3d9177a0ef..8164830b26 100644 --- a/field/babybear/vector_purego.go +++ b/field/babybear/vector_purego.go @@ -1,4 +1,4 @@ -//go:build purego || !amd64 +//go:build purego || (!amd64 && !arm64) // Copyright 2020-2024 Consensys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index c0d617410a..90befefabd 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -55,6 +55,15 @@ func GenerateCommonASM(w io.Writer, nbWords, nbBits int, hasVector bool) error { f.WriteLn("#include \"go_asm.h\"") f.WriteLn("") + if nbWords == 1 { + if nbBits == 31 { + return GenerateF31ASM(f, hasVector) + } else { + panic("not implemented") + } + + } + if f.NbWords%2 != 0 { panic("NbWords must be even") } @@ -216,3 +225,18 @@ func ElementASMFileName(nbWords, nbBits int) string { } return fmt.Sprintf(nameWN, nbWords) } + +func GenerateF31ASM(f *FFArm64, hasVector bool) error { + if !hasVector { + return nil // nothing for now. 
+ } + + // f.generateAddVecF31() + // f.generateSubVecF31() + // f.generateSumVecF31() + // f.generateMulVecF31() + // f.generateScalarMulVecF31() + // f.generateInnerProdVecF31() + + return nil +} diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 20314f4098..5f0c4ac939 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -290,8 +290,8 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // asm code generation for moduli with more than 6 words can be optimized further F.GenerateOpsAMD64 = F.F31 || (F.NoCarry && F.NbWords <= 12 && F.NbWords > 1) F.GenerateVectorOpsAMD64 = F.F31 || (F.GenerateOpsAMD64 && F.NbWords == 4 && F.NbBits > 225) - F.GenerateOpsARM64 = F.GenerateOpsAMD64 && (F.NbWords%2 == 0) - F.GenerateVectorOpsARM64 = false + F.GenerateOpsARM64 = F.F31 || (F.GenerateOpsAMD64 && (F.NbWords%2 == 0)) + F.GenerateVectorOpsARM64 = F.F31 // setting Mu 2^288 / q if F.NbWords == 4 { diff --git a/field/generator/generator.go b/field/generator/generator.go index 99f8540f45..82ea888f0f 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -142,7 +142,6 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude if F.F31 { pureGoBuildTag = "" // always generate pure go for F31 - pureGoVectorBuildTag = "purego || (!amd64)" } var g errgroup.Group @@ -158,13 +157,14 @@ func GenerateFF(F *config.FieldConfig, outputDir, asmDirBuildPath, asmDirInclude g.Go(generate("element_arm64.s", []string{element.IncludeASM}, Only(F.GenerateOpsARM64), WithBuildTag("!purego"), WithData(arm64d))) g.Go(generate("element_amd64.go", []string{element.OpsAMD64, element.MulDoc}, Only(F.GenerateOpsAMD64 && !F.F31), WithBuildTag("!purego"))) - g.Go(generate("element_arm64.go", []string{element.OpsARM64, element.MulNoCarry, element.Reduce}, Only(F.GenerateOpsARM64), WithBuildTag("!purego"))) + g.Go(generate("element_arm64.go", []string{element.OpsARM64, element.MulNoCarry, element.Reduce}, Only(F.GenerateOpsARM64 && !F.F31), WithBuildTag("!purego"))) g.Go(generate("element_purego.go", []string{element.OpsNoAsm, element.MulCIOS, element.MulNoCarry, element.Reduce, element.MulDoc}, WithBuildTag(pureGoBuildTag))) g.Go(generate("vector_amd64.go", []string{element.VectorOpsAmd64}, Only(F.GenerateVectorOpsAMD64 && !F.F31), WithBuildTag("!purego"))) g.Go(generate("vector_amd64.go", []string{element.VectorOpsAmd64F31}, Only(F.GenerateVectorOpsAMD64 && F.F31), WithBuildTag("!purego"))) - g.Go(generate("vector_arm64.go", []string{element.VectorOpsArm64}, Only(F.GenerateVectorOpsARM64), WithBuildTag("!purego"))) + g.Go(generate("vector_arm64.go", []string{element.VectorOpsArm64}, Only(F.GenerateVectorOpsARM64 && !F.F31), WithBuildTag("!purego"))) + g.Go(generate("vector_arm64.go", []string{element.VectorOpsArm64F31}, Only(F.GenerateVectorOpsARM64 && F.F31), WithBuildTag("!purego"))) g.Go(generate("vector_purego.go", []string{element.VectorOpsPureGo}, WithBuildTag(pureGoVectorBuildTag))) diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index aa00830dfd..bc15980378 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -146,6 +146,8 @@ func mulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64) const VectorOpsArm64 = VectorOpsPureGo +const VectorOpsArm64F31 = VectorOpsPureGo + 
const VectorOpsAmd64F31 = ` //go:noescape diff --git a/field/internal/main.go b/field/internal/main.go index 29d84091c1..bd0d50c245 100644 --- a/field/internal/main.go +++ b/field/internal/main.go @@ -30,6 +30,11 @@ func main() { panic(err) } + // generate arm + if err := generator.GenerateARM64(1, 31, asmDir, true); err != nil { + panic(err) + } + for _, f := range fields { fc, err := config.NewFieldConfig(f.name, "Element", f.modulus, true) if err != nil { diff --git a/field/koalabear/doc.go b/field/koalabear/doc.go index 60c2390f6a..c102c15c78 100644 --- a/field/koalabear/doc.go +++ b/field/koalabear/doc.go @@ -7,7 +7,7 @@ // // The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x). // -// Additionally koalabear.Vector offers an API to manipulate []Element using AVX512 instructions if available. +// Additionally koalabear.Vector offers an API to manipulate []Element using AVX512/NEON instructions if available. // // The modulus is hardcoded in all the operations. // diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s new file mode 100644 index 0000000000..2b0c2ab64c --- /dev/null +++ b/field/koalabear/element_arm64.s @@ -0,0 +1,10 @@ +//go:build !purego + +// Copyright 2020-2024 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// We include the hash to force the Go compiler to recompile: 526586546912891733 +#include "../asm/element_31b_arm64.s" + diff --git a/field/koalabear/vector_arm64.go b/field/koalabear/vector_arm64.go new file mode 100644 index 0000000000..733319e548 --- /dev/null +++ b/field/koalabear/vector_arm64.go @@ -0,0 +1,45 @@ +//go:build !purego + +// Copyright 2020-2024 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package koalabear + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/field/koalabear/vector_purego.go b/field/koalabear/vector_purego.go index c96e144f86..fa9715eaf5 100644 --- a/field/koalabear/vector_purego.go +++ b/field/koalabear/vector_purego.go @@ -1,4 +1,4 @@ -//go:build purego || !amd64 +//go:build purego || (!amd64 && !arm64) // Copyright 2020-2024 Consensys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. From fd46a1e00fab6e19996e0e3a078531ecbb6d938a Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 10 Dec 2024 16:15:57 -0600 Subject: [PATCH 59/74] checkpoint --- field/asm/element_31b_arm64.s | 19 ++++++ field/babybear/element_arm64.s | 2 +- field/babybear/vector_arm64.go | 22 ++++++- field/generator/asm/arm64/build.go | 2 +- field/generator/asm/arm64/element_vec_F31.go | 57 ++++++++++++++++++ .../templates/element/vector_ops_asm.go | 59 ++++++++++++++++++- field/koalabear/element_arm64.s | 2 +- field/koalabear/vector_arm64.go | 22 ++++++- go.mod | 4 +- 9 files changed, 182 insertions(+), 7 deletions(-) create mode 100644 field/generator/asm/arm64/element_vec_F31.go diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index f01ac64962..9906479334 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -3,3 +3,22 @@ #include "funcdata.h" #include "go_asm.h" +// addVec(res, a, b *Element, n uint64) +TEXT ·addVec(SB), NOFRAME|NOSPLIT, $0-32 + MOVD $const_q, R6 + LDP res+0(FP), (R0, R1) + LDP b+16(FP), (R2, R3) + +loop1: + CBZ R3, done2 + MOVWU.P 4(R1), R5 + MOVWU.P 4(R2), R4 + ADD R5, R4, R4 + SUBS R6, R4, R7 + CSEL CS, R7, R4, R7 + MOVWU.P R4, 4(R0) + SUB $1, R3, R3 + JMP loop1 + +done2: + RET diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index 2b0c2ab64c..6ca9c99f0d 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 526586546912891733 +// We include the hash to force the Go compiler to recompile: 8072276761764543562 #include "../asm/element_31b_arm64.s" diff --git a/field/babybear/vector_arm64.go b/field/babybear/vector_arm64.go index 0f91a0fc74..d3c36b1039 100644 --- a/field/babybear/vector_arm64.go +++ b/field/babybear/vector_arm64.go @@ -7,10 +7,30 @@ package babybear +//go:noescape +func addVec(res, a, b *Element, n uint64) + // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + + const blockSize = 1 + addVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + for i := 0; i < len(*vector); i++ { + (*vector)[i][0] %= q + } + // if n % blockSize != 0 { + // // call addVecGeneric on the rest + // start := n - n % blockSize + // addVecGeneric((*vector)[start:], a[start:], b[start:]) + // } } // Sub subtracts two vectors element-wise and stores the result in self. diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 90befefabd..ce2759044c 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -231,7 +231,7 @@ func GenerateF31ASM(f *FFArm64, hasVector bool) error { return nil // nothing for now. 
} - // f.generateAddVecF31() + f.generateAddVecF31() // f.generateSubVecF31() // f.generateSumVecF31() // f.generateMulVecF31() diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go new file mode 100644 index 0000000000..76b8c4b4e9 --- /dev/null +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -0,0 +1,57 @@ +package arm64 + +func (f *FFArm64) generateAddVecF31() { + f.Comment("addVec(res, a, b *Element, n uint64)") + registers := f.FnHeader("addVec", 0, 32) + defer f.AssertCleanStack(0, 0) + + // registers + resPtr := registers.Pop() + aPtr := registers.Pop() + bPtr := registers.Pop() + n := registers.Pop() + + b := registers.Pop() + a := registers.Pop() + q := registers.Pop() + t := registers.Pop() + + f.MOVD("$const_q", q) + + // labels + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.LDP("res+0(FP)", resPtr, aPtr) + f.LDP("b+16(FP)", bPtr, n) + + f.LABEL(loop) + + f.CBZ(n, done) + + // load a + f.MOVWUP_Load(4, aPtr, a) + // load b + f.MOVWUP_Load(4, bPtr, b) + + // res = a + b + f.ADD(a, b, b) + + // t = res - q + f.SUBS(q, b, t) + + // t = min(t, res) + f.CSEL("CS", t, b, t) + + // res = t + f.MOVWUP_Store(b, resPtr, 4) + + // decrement n + f.SUB(1, n, n) + f.JMP(loop) + + f.LABEL(done) + f.RET() + +} diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index bc15980378..9acf1b8075 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -146,7 +146,64 @@ func mulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64) const VectorOpsArm64 = VectorOpsPureGo -const VectorOpsArm64F31 = VectorOpsPureGo +const VectorOpsArm64F31 = ` +//go:noescape +func addVec(res, a, b *{{.ElementName}}, n uint64) + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + + const blockSize = 1 + addVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + for i := 0; i < len(*vector); i++ { + (*vector)[i][0] %= q + } + // if n % blockSize != 0 { + // // call addVecGeneric on the rest + // start := n - n % blockSize + // addVecGeneric((*vector)[start:], a[start:], b[start:]) + // } +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res {{.ElementName}}) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. 
+func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} +` const VectorOpsAmd64F31 = ` diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index 2b0c2ab64c..6ca9c99f0d 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 526586546912891733 +// We include the hash to force the Go compiler to recompile: 8072276761764543562 #include "../asm/element_31b_arm64.s" diff --git a/field/koalabear/vector_arm64.go b/field/koalabear/vector_arm64.go index 733319e548..8a58770797 100644 --- a/field/koalabear/vector_arm64.go +++ b/field/koalabear/vector_arm64.go @@ -7,10 +7,30 @@ package koalabear +//go:noescape +func addVec(res, a, b *Element, n uint64) + // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Add: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + + const blockSize = 1 + addVec(&(*vector)[0], &a[0], &b[0], n/blockSize) + for i := 0; i < len(*vector); i++ { + (*vector)[i][0] %= q + } + // if n % blockSize != 0 { + // // call addVecGeneric on the rest + // start := n - n % blockSize + // addVecGeneric((*vector)[start:], a[start:], b[start:]) + // } } // Sub subtracts two vectors element-wise and stores the result in self. diff --git a/go.mod b/go.mod index e52cac12d6..f3e814ffcb 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.24 + github.com/consensys/bavard v0.0.0 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 @@ -15,6 +15,8 @@ require ( gopkg.in/yaml.v2 v2.4.0 ) +replace "github.com/consensys/bavard" => ../bavard + require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect From 93c01c4e90d8840e4d74b365296df80cd607d6f9 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 10 Dec 2024 20:21:22 -0600 Subject: [PATCH 60/74] feat: adds f31 neon add --- field/asm/element_31b_arm64.s | 29 ++++++------ field/babybear/element_arm64.s | 2 +- field/babybear/vector_arm64.go | 28 +++++++----- field/generator/asm/arm64/element_vec_F31.go | 45 +++++++++---------- .../templates/element/vector_ops_asm.go | 31 ++++++++----- field/koalabear/element_arm64.s | 2 +- field/koalabear/vector_arm64.go | 28 +++++++----- 7 files changed, 94 insertions(+), 71 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index 9906479334..1e5357c49c 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -3,22 +3,23 @@ #include "funcdata.h" #include "go_asm.h" -// addVec(res, a, b *Element, n uint64) -TEXT ·addVec(SB), NOFRAME|NOSPLIT, $0-32 - MOVD $const_q, R6 - LDP res+0(FP), (R0, R1) - LDP b+16(FP), (R2, R3) +// addVec(qq *uint32, res, a, b *Element, n uint64) +TEXT ·addVec(SB), NOFRAME|NOSPLIT, $0-40 + LDP qq+0(FP), (R1, R0) + LDP a+16(FP), (R2, R3) + MOVD n+32(FP), R4 + VLD1 0(R1), [V2.S4] // broadcast q into V2.S4 loop1: - CBZ R3, done2 - MOVWU.P 4(R1), R5 - MOVWU.P 4(R2), R4 - ADD R5, R4, R4 - SUBS R6, R4, R7 - CSEL CS, R7, R4, R7 - MOVWU.P R4, 4(R0) - SUB $1, R3, R3 - JMP 
loop1
+	CBZ R4, done2
+	VLD1.P 16(R2), [V0.S4]
+	VLD1.P 16(R3), [V1.S4]
+	VADD V0.S4, V1.S4, V1.S4 // b = a + b
+	VSUB V2.S4, V1.S4, V3.S4 // t = b - q
+	VUMIN V3.S4, V1.S4, V1.S4 // b = min(t, b)
+	VST1.P [V1.S4], 16(R0) // res = b
+	SUB $1, R4, R4
+	JMP loop1
 
 done2:
 	RET
diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s
index 6ca9c99f0d..b777a9e0bc 100644
--- a/field/babybear/element_arm64.s
+++ b/field/babybear/element_arm64.s
@@ -5,6 +5,6 @@
 
 // Code generated by consensys/gnark-crypto DO NOT EDIT
 
-// We include the hash to force the Go compiler to recompile: 8072276761764543562
+// We include the hash to force the Go compiler to recompile: 6520165234218162435
 #include "../asm/element_31b_arm64.s"
 
diff --git a/field/babybear/vector_arm64.go b/field/babybear/vector_arm64.go
index d3c36b1039..2f4d37ac30 100644
--- a/field/babybear/vector_arm64.go
+++ b/field/babybear/vector_arm64.go
@@ -7,8 +7,19 @@
 
 package babybear
 
+// qLane is a vector with all elements set to q
+// TODO figure out why the arm64 assembly to broadcast a scalar is not working
+var qLane [4]uint32
+
+func init() {
+	qLane[0] = q
+	qLane[1] = q
+	qLane[2] = q
+	qLane[3] = q
+}
+
 //go:noescape
-func addVec(res, a, b *Element, n uint64)
+func addVec(qq *uint32, res, a, b *Element, n uint64)
 
 // Add adds two vectors element-wise and stores the result in self.
 // It panics if the vectors don't have the same length.
@@ -21,16 +32,13 @@ func (vector *Vector) Add(a, b Vector) {
 		return
 	}
 
-	const blockSize = 1
-	addVec(&(*vector)[0], &a[0], &b[0], n/blockSize)
-	for i := 0; i < len(*vector); i++ {
-		(*vector)[i][0] %= q
+	const blockSize = 4
+	addVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize)
+	if n%blockSize != 0 {
+		// call addVecGeneric on the rest
+		start := n - n%blockSize
+		addVecGeneric((*vector)[start:], a[start:], b[start:])
 	}
-	// if n % blockSize != 0 {
-	// 	// call addVecGeneric on the rest
-	// 	start := n - n % blockSize
-	// 	addVecGeneric((*vector)[start:], a[start:], b[start:])
-	// }
 }
 
 // Sub subtracts two vectors element-wise and stores the result in self.
diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go
index 76b8c4b4e9..934b84df98 100644
--- a/field/generator/asm/arm64/element_vec_F31.go
+++ b/field/generator/asm/arm64/element_vec_F31.go
@@ -1,51 +1,48 @@
 package arm64
 
+import "github.com/consensys/bavard/arm64"
+
 func (f *FFArm64) generateAddVecF31() {
-	f.Comment("addVec(res, a, b *Element, n uint64)")
-	registers := f.FnHeader("addVec", 0, 32)
+	f.Comment("addVec(qq *uint32, res, a, b *Element, n uint64)")
+	registers := f.FnHeader("addVec", 0, 40)
 	defer f.AssertCleanStack(0, 0)
 
 	// registers
 	resPtr := registers.Pop()
+	qqPtr := registers.Pop()
 	aPtr := registers.Pop()
 	bPtr := registers.Pop()
 	n := registers.Pop()
 
-	b := registers.Pop()
-	a := registers.Pop()
-	q := registers.Pop()
-	t := registers.Pop()
-
-	f.MOVD("$const_q", q)
+	a := arm64.V0.S4()
+	b := arm64.V1.S4()
+	q := arm64.V2.S4()
+	t := arm64.V3.S4()
 
 	// labels
 	loop := f.NewLabel("loop")
 	done := f.NewLabel("done")
 
 	// load arguments
-	f.LDP("res+0(FP)", resPtr, aPtr)
-	f.LDP("b+16(FP)", bPtr, n)
+	f.LDP("qq+0(FP)", qqPtr, resPtr)
+	f.LDP("a+16(FP)", aPtr, bPtr)
+	f.MOVD("n+32(FP)", n)
+
+	f.VLD1(0, qqPtr, q, "broadcast q into "+string(q))
 
 	f.LABEL(loop)
 
 	f.CBZ(n, done)
 
-	// load a
-	f.MOVWUP_Load(4, aPtr, a)
-	// load b
-	f.MOVWUP_Load(4, bPtr, b)
-
-	// res = a + b
-	f.ADD(a, b, b)
-
-	// t = res - q
-	f.SUBS(q, b, t)
-
-	// t = min(t, res)
-	f.CSEL("CS", t, b, t)
-
-	// res = t
-	f.MOVWUP_Store(b, resPtr, 4)
+	const offset = 4 * 4 // we process 4 uint32 at a time
+
+	f.VLD1_P(offset, aPtr, a)
+	f.VLD1_P(offset, bPtr, b)
+
+	f.VADD(a, b, b, "b = a + b")
+	f.VSUB(q, b, t, "t = b - q")
+	f.VUMIN(t, b, b, "b = min(t, b)")
+	f.VST1_P(b, resPtr, offset, "res = b")
 
 	// decrement n
 	f.SUB(1, n, n)
diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go
index 9acf1b8075..5153c125d0 100644
--- a/field/generator/internal/templates/element/vector_ops_asm.go
+++ b/field/generator/internal/templates/element/vector_ops_asm.go
@@ -147,8 +147,20 @@ func mulVec(res, a, b *{{.ElementName}}, n uint64, qInvNeg uint64)
 
 const VectorOpsArm64 = VectorOpsPureGo
 
 const VectorOpsArm64F31 = `
+
+
+// qLane is a vector with all elements set to q
+// TODO figure out why the arm64 assembly to broadcast a scalar is not working
+var qLane [4]uint32
+func init() {
+	qLane[0] = q
+	qLane[1] = q
+	qLane[2] = q
+	qLane[3] = q
+}
+
 //go:noescape
-func addVec(res, a, b *{{.ElementName}}, n uint64)
+func addVec(qq *uint32, res, a, b *{{.ElementName}}, n uint64)
 
 // Add adds two vectors element-wise and stores the result in self.
 // It panics if the vectors don't have the same length.
@@ -161,16 +173,13 @@ func (vector *Vector) Add(a, b Vector) {
 		return
 	}
 
-	const blockSize = 1
-	addVec(&(*vector)[0], &a[0], &b[0], n/blockSize)
-	for i := 0; i < len(*vector); i++ {
-		(*vector)[i][0] %= q
-	}
-	// if n % blockSize != 0 {
-	// 	// call addVecGeneric on the rest
-	// 	start := n - n % blockSize
-	// 	addVecGeneric((*vector)[start:], a[start:], b[start:])
-	// }
+	const blockSize = 4
+	addVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize)
+	if n % blockSize != 0 {
+		// call addVecGeneric on the rest
+		start := n - n % blockSize
+		addVecGeneric((*vector)[start:], a[start:], b[start:])
+	}
 }
 
 // Sub subtracts two vectors element-wise and stores the result in self.
diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index 6ca9c99f0d..b777a9e0bc 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 8072276761764543562 +// We include the hash to force the Go compiler to recompile: 6520165234218162435 #include "../asm/element_31b_arm64.s" diff --git a/field/koalabear/vector_arm64.go b/field/koalabear/vector_arm64.go index 8a58770797..ff0f7837cd 100644 --- a/field/koalabear/vector_arm64.go +++ b/field/koalabear/vector_arm64.go @@ -7,8 +7,19 @@ package koalabear +// qLane is a vector with all elements set to q +// TODO figure out why the arm64 assembly to broadcast a scalar is not working +var qLane [4]uint32 + +func init() { + qLane[0] = q + qLane[1] = q + qLane[2] = q + qLane[3] = q +} + //go:noescape -func addVec(res, a, b *Element, n uint64) +func addVec(qq *uint32, res, a, b *Element, n uint64) // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. @@ -21,16 +32,13 @@ func (vector *Vector) Add(a, b Vector) { return } - const blockSize = 1 - addVec(&(*vector)[0], &a[0], &b[0], n/blockSize) - for i := 0; i < len(*vector); i++ { - (*vector)[i][0] %= q + const blockSize = 4 + addVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize) + if n%blockSize != 0 { + // call addVecGeneric on the rest + start := n - n%blockSize + addVecGeneric((*vector)[start:], a[start:], b[start:]) } - // if n % blockSize != 0 { - // // call addVecGeneric on the rest - // start := n - n % blockSize - // addVecGeneric((*vector)[start:], a[start:], b[start:]) - // } } // Sub subtracts two vectors element-wise and stores the result in self. 
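A scalar Go sketch of the branchless reduction that the NEON addVec above performs lane-wise (illustrative only; the function name is hypothetical, q is passed explicitly here while the generated code reads it as a package constant, and both inputs are assumed already reduced, i.e. a, b < q < 2^31):

	func addMod(a, b, q uint32) uint32 {
		s := a + b // s < 2q < 2^32, so the sum cannot overflow
		t := s - q // wraps to a large value when s < q
		if t < s { // unsigned min(s, s-q), the scalar analogue of VUMIN
			return t
		}
		return s
	}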
From daa66420484037d298c1140e5fe0cb5514981367 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 10 Dec 2024 20:26:17 -0600 Subject: [PATCH 61/74] feat: adds f31 neon sub --- field/asm/element_31b_arm64.s | 21 ++++++++ field/babybear/element_arm64.s | 2 +- field/babybear/vector_arm64.go | 21 +++++++- field/generator/asm/arm64/build.go | 2 +- field/generator/asm/arm64/element_vec_F31.go | 51 +++++++++++++++++++ .../templates/element/vector_ops_asm.go | 21 +++++++- field/koalabear/element_arm64.s | 2 +- field/koalabear/vector_arm64.go | 21 +++++++- 8 files changed, 132 insertions(+), 9 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index 1e5357c49c..6777185b84 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -23,3 +23,24 @@ loop1: done2: RET + +// subVec(qq *uint32, res, a, b *Element, n uint64) +TEXT ·subVec(SB), NOFRAME|NOSPLIT, $0-40 + LDP qLane+0(FP), (R1, R0) + LDP a+16(FP), (R2, R3) + MOVD n+32(FP), R4 + VLD1 0(R1), [V2.S4] // broadcast q into V2.S4 + +loop3: + CBZ R4, done4 + VLD1.P 16(R2), [V0.S4] + VLD1.P 16(R3), [V1.S4] + VSUB V1.S4, V0.S4, V1.S4 // b = a - b + VADD V1.S4, V2.S4, V3.S4 // t = b + q + VUMIN V3.S4, V1.S4, V1.S4 // b = min(t, b) + VST1.P [V1.S4], 16(R0) // res = b + SUB $1, R4, R4 + JMP loop3 + +done4: + RET diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index b777a9e0bc..143df0bfc9 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 6520165234218162435 +// We include the hash to force the Go compiler to recompile: 8339322716531241176 #include "../asm/element_31b_arm64.s" diff --git a/field/babybear/vector_arm64.go b/field/babybear/vector_arm64.go index 2f4d37ac30..31e6477999 100644 --- a/field/babybear/vector_arm64.go +++ b/field/babybear/vector_arm64.go @@ -19,7 +19,10 @@ func init() { } //go:noescape -func addVec(qq *uint32, res, a, b *Element, n uint64) +func addVec(qLane *uint32, res, a, b *Element, n uint64) + +//go:noescape +func subVec(qLane *uint32, res, a, b *Element, n uint64) // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. @@ -44,7 +47,21 @@ func (vector *Vector) Add(a, b Vector) { // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + + const blockSize = 4 + subVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize) + if n%blockSize != 0 { + // call subVecGeneric on the rest + start := n - n%blockSize + subVecGeneric((*vector)[start:], a[start:], b[start:]) + } } // ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. 
diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index ce2759044c..21651bbc7c 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -232,7 +232,7 @@ func GenerateF31ASM(f *FFArm64, hasVector bool) error { } f.generateAddVecF31() - // f.generateSubVecF31() + f.generateSubVecF31() // f.generateSumVecF31() // f.generateMulVecF31() // f.generateScalarMulVecF31() diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index 934b84df98..182e1604cc 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -52,3 +52,54 @@ func (f *FFArm64) generateAddVecF31() { f.RET() } + +func (f *FFArm64) generateSubVecF31() { + f.Comment("subVec(qq *uint32, res, a, b *Element, n uint64)") + registers := f.FnHeader("subVec", 0, 40) + defer f.AssertCleanStack(0, 0) + + // registers + resPtr := registers.Pop() + qPtr := registers.Pop() + aPtr := registers.Pop() + bPtr := registers.Pop() + n := registers.Pop() + + a := arm64.V0.S4() + b := arm64.V1.S4() + q := arm64.V2.S4() + t := arm64.V3.S4() + + // labels + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.LDP("qLane+0(FP)", qPtr, resPtr) + f.LDP("a+16(FP)", aPtr, bPtr) + f.MOVD("n+32(FP)", n) + + f.VLD1(0, qPtr, q, "broadcast q into "+string(q)) + + f.LABEL(loop) + + f.CBZ(n, done) + + const offset = 4 * 4 // we process 4 uint32 at a time + + f.VLD1_P(offset, aPtr, a) + f.VLD1_P(offset, bPtr, b) + + f.VSUB(b, a, b, "b = a - b") + f.VADD(b, q, t, "t = b + q") + f.VUMIN(t, b, b, "b = min(t, b)") + f.VST1_P(b, resPtr, offset, "res = b") + + // decrement n + f.SUB(1, n, n) + f.JMP(loop) + + f.LABEL(done) + f.RET() + +} diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index 5153c125d0..f8689305fa 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -160,7 +160,10 @@ func init() { } //go:noescape -func addVec(qq *uint32, res, a, b *{{.ElementName}}, n uint64) +func addVec(qLane *uint32, res, a, b *{{.ElementName}}, n uint64) + +//go:noescape +func subVec(qLane *uint32, res, a, b *{{.ElementName}}, n uint64) // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. @@ -185,7 +188,21 @@ func (vector *Vector) Add(a, b Vector) { // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + + const blockSize = 4 + subVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize) + if n % blockSize != 0 { + // call subVecGeneric on the rest + start := n - n % blockSize + subVecGeneric((*vector)[start:], a[start:], b[start:]) + } } // ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. 
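The subtraction kernel generated above uses the mirrored trick: the difference may wrap below zero, so it keeps min(d, d+q) instead (again an illustrative scalar sketch with hypothetical names, assuming a, b < q < 2^31):

	func subMod(a, b, q uint32) uint32 {
		d := a - b // wraps modulo 2^32 when a < b
		t := d + q // after a wrap, this lands back in [0, q)
		if t < d { // unsigned min(d, d+q), as VUMIN does per lane
			return t
		}
		return d
	}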
diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index b777a9e0bc..143df0bfc9 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 6520165234218162435 +// We include the hash to force the Go compiler to recompile: 8339322716531241176 #include "../asm/element_31b_arm64.s" diff --git a/field/koalabear/vector_arm64.go b/field/koalabear/vector_arm64.go index ff0f7837cd..b539cd9a7b 100644 --- a/field/koalabear/vector_arm64.go +++ b/field/koalabear/vector_arm64.go @@ -19,7 +19,10 @@ func init() { } //go:noescape -func addVec(qq *uint32, res, a, b *Element, n uint64) +func addVec(qLane *uint32, res, a, b *Element, n uint64) + +//go:noescape +func subVec(qLane *uint32, res, a, b *Element, n uint64) // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. @@ -44,7 +47,21 @@ func (vector *Vector) Add(a, b Vector) { // Sub subtracts two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) + if len(a) != len(b) || len(a) != len(*vector) { + panic("vector.Sub: vectors don't have the same length") + } + n := uint64(len(a)) + if n == 0 { + return + } + + const blockSize = 4 + subVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize) + if n%blockSize != 0 { + // call subVecGeneric on the rest + start := n - n%blockSize + subVecGeneric((*vector)[start:], a[start:], b[start:]) + } } // ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. From 48206ac4ca0b5fcf7278e659dede522c177a070e Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 10 Dec 2024 21:36:26 -0600 Subject: [PATCH 62/74] feat: add neon f31 sum --- field/asm/element_31b_arm64.s | 30 ++++++++ field/babybear/element_arm64.s | 2 +- field/babybear/vector_amd64.go | 6 +- field/babybear/vector_arm64.go | 24 ++++++- field/generator/asm/arm64/build.go | 2 +- field/generator/asm/arm64/element_vec_F31.go | 72 +++++++++++++++++++ .../templates/element/vector_ops_asm.go | 31 ++++++-- field/koalabear/element_arm64.s | 2 +- field/koalabear/vector_amd64.go | 6 +- field/koalabear/vector_arm64.go | 24 ++++++- 10 files changed, 181 insertions(+), 18 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index 6777185b84..74b6c9b6f9 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -44,3 +44,33 @@ loop3: done4: RET + +// sumVec(t *uint64, a *[]uint32, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOFRAME|NOSPLIT, $0-24 + VMOVQ $0, $0, V4 + VMOVQ $0, $0, V5 + VMOVQ $0, $0, V6 + VMOVQ $0, $0, V7 + LDP t+0(FP), (R1, R0) + MOVD n+16(FP), R2 + +loop5: + CBZ R2, done6 + VLD2.P 16(R0), [V0.S2, V1.S2] + VLD2.P 16(R0), [V2.S2, V3.S2] + VUSHLL $0, V0.S2, V0.D2 // convert to 64 bits + VUSHLL $0, V1.S2, V1.D2 // convert to 64 bits + VADD V0.D2, V4.D2, V4.D2 // acc1 += a1 + VADD V1.D2, V5.D2, V5.D2 // acc2 += a2 + VUSHLL $0, V2.S2, V2.D2 // convert to 64 bits + VUSHLL $0, V3.S2, V3.D2 // convert to 64 bits + VADD V2.D2, V6.D2, V6.D2 // acc3 += a3 + VADD V3.D2, V7.D2, V7.D2 // acc4 += a4 + SUB $1, R2, R2 + JMP loop5 + +done6: + VADD V4.D2, V6.D2, V4.D2 // acc1 += acc3 + VADD V5.D2, V7.D2, V5.D2 // acc2 += acc4 + VST2.P [V4.D2, V5.D2], 0(R1) // store acc1 and acc2 + RET diff --git 
a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index 143df0bfc9..e9464dcb8b 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 8339322716531241176 +// We include the hash to force the Go compiler to recompile: 13481038126710208331 #include "../asm/element_31b_arm64.s" diff --git a/field/babybear/vector_amd64.go b/field/babybear/vector_amd64.go index 2ace506476..34573016f4 100644 --- a/field/babybear/vector_amd64.go +++ b/field/babybear/vector_amd64.go @@ -116,10 +116,9 @@ func (vector *Vector) Sum() (res Element) { var t [8]uint64 // stores the accumulators (not reduced mod q) sumVec(&t[0], &(*vector)[0], n/blockSize) // we reduce the accumulators mod q and add to res - var v Element for i := 0; i < 8; i++ { t[i] %= q - v[0] = uint32(t[i]) + v[0] = uint32(t[i] % q) res.Add(&res, &v) } if n%blockSize != 0 { @@ -153,8 +152,7 @@ func (vector *Vector) InnerProduct(other Vector) (res Element) { // we reduce the accumulators mod q and add to res var v Element for i := 0; i < 8; i++ { - t[i] %= q - v[0] = uint32(t[i]) + v[0] = uint32(t[i] % q) res.Add(&res, &v) } if n%blockSize != 0 { diff --git a/field/babybear/vector_arm64.go b/field/babybear/vector_arm64.go index 31e6477999..88e175dcdd 100644 --- a/field/babybear/vector_arm64.go +++ b/field/babybear/vector_arm64.go @@ -24,6 +24,9 @@ func addVec(qLane *uint32, res, a, b *Element, n uint64) //go:noescape func subVec(qLane *uint32, res, a, b *Element, n uint64) +//go:noescape +func sumVec(t *uint64, a *Element, n uint64) + // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -72,7 +75,26 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { // Sum computes the sum of all elements in the vector. 
func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) + n := uint64(len(*vector)) + if n == 0 { + return + } + + const blockSize = 8 + var t [4]uint64 // stores the accumulators (not reduced mod q) + sumVec(&t[0], &(*vector)[0], n/blockSize) + // we reduce the accumulators mod q and add to res + var v Element + for i := 0; i < 4; i++ { + v[0] = uint32(t[i] % q) + res.Add(&res, &v) + } + if n%blockSize != 0 { + // call sumVecGeneric on the rest + start := n - n%blockSize + sumVecGeneric(&res, (*vector)[start:]) + } + return } diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 21651bbc7c..0c7fdd77bb 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -233,7 +233,7 @@ func GenerateF31ASM(f *FFArm64, hasVector bool) error { f.generateAddVecF31() f.generateSubVecF31() - // f.generateSumVecF31() + f.generateSumVecF31() // f.generateMulVecF31() // f.generateScalarMulVecF31() // f.generateInnerProdVecF31() diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index 182e1604cc..43b208ca0c 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -103,3 +103,75 @@ func (f *FFArm64) generateSubVecF31() { f.RET() } + +func (f *FFArm64) generateSumVecF31() { + f.Comment("sumVec(t *uint64, a *[]uint32, n uint64) res = sum(a[0...n])") + registers := f.FnHeader("sumVec", 0, 3*8) + defer f.AssertCleanStack(0, 0) + + // registers + aPtr := registers.Pop() + tPtr := registers.Pop() + n := registers.Pop() + + a1 := arm64.V0 + a2 := arm64.V1 + a3 := arm64.V2 + a4 := arm64.V3 + acc1 := arm64.V4 + acc2 := arm64.V5 + acc3 := arm64.V6 + acc4 := arm64.V7 + + // zero out accumulators + f.VMOVQ_cst(0, 0, acc1) + f.VMOVQ_cst(0, 0, acc2) + f.VMOVQ_cst(0, 0, acc3) + f.VMOVQ_cst(0, 0, acc4) + + acc1 = arm64.V4.D2() + acc2 = arm64.V5.D2() + acc3 = arm64.V6.D2() + acc4 = arm64.V7.D2() + + // labels + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.LDP("t+0(FP)", tPtr, aPtr) + f.MOVD("n+16(FP)", n) + + f.LABEL(loop) + + f.CBZ(n, done) + + const offset = 8 * 4 // we process 4 uint32 at a time + + f.VLD2_P(offset/2, aPtr, a1.S2(), a2.S2()) // load 2*2 uint32 + f.VLD2_P(offset/2, aPtr, a3.S2(), a4.S2()) // load 2*2 uint32 + + f.VUSHLL(0, a1.S2(), a1.D2(), "convert to 64 bits") + f.VUSHLL(0, a2.S2(), a2.D2(), "convert to 64 bits") + f.VADD(a1.D2(), acc1, acc1, "acc1 += a1") + f.VADD(a2.D2(), acc2, acc2, "acc2 += a2") + + f.VUSHLL(0, a3.S2(), a3.D2(), "convert to 64 bits") + f.VUSHLL(0, a4.S2(), a4.D2(), "convert to 64 bits") + f.VADD(a3.D2(), acc3, acc3, "acc3 += a3") + f.VADD(a4.D2(), acc4, acc4, "acc4 += a4") + + // decrement n + f.SUB(1, n, n) + f.JMP(loop) + + f.LABEL(done) + + f.VADD(acc1, acc3, acc1, "acc1 += acc3") + f.VADD(acc2, acc4, acc2, "acc2 += acc4") + + f.VST2_P(acc1, acc2, tPtr, 0, "store acc1 and acc2") + + f.RET() + +} diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index f8689305fa..a6c4c99a84 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -165,6 +165,10 @@ func addVec(qLane *uint32, res, a, b *{{.ElementName}}, n uint64) //go:noescape func subVec(qLane *uint32, res, a, b *{{.ElementName}}, n uint64) +//go:noescape +func sumVec(t *uint64, a *{{.ElementName}}, n uint64) + + // Add adds two 
vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -213,7 +217,26 @@ func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res {{.ElementName}}) { - sumVecGeneric(&res, *vector) + n := uint64(len(*vector)) + if n == 0 { + return + } + + const blockSize = 8 + var t [4]uint64 // stores the accumulators (not reduced mod q) + sumVec(&t[0], &(*vector)[0], n/blockSize) + // we reduce the accumulators mod q and add to res + var v {{.ElementName}} + for i := 0; i < 4; i++ { + v[0] = uint32(t[i] % q) + res.Add(&res, &v) + } + if n % blockSize != 0 { + // call sumVecGeneric on the rest + start := n - n % blockSize + sumVecGeneric(&res, (*vector)[start:]) + } + return } @@ -342,10 +365,9 @@ func (vector *Vector) Sum() (res {{.ElementName}}) { var t [8]uint64 // stores the accumulators (not reduced mod q) sumVec(&t[0], &(*vector)[0], n/blockSize) // we reduce the accumulators mod q and add to res - var v {{.ElementName}} for i := 0; i < 8; i++ { t[i] %= q - v[0] = uint32(t[i]) + v[0] = uint32(t[i] % q) res.Add(&res, &v) } if n % blockSize != 0 { @@ -379,8 +401,7 @@ func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { // we reduce the accumulators mod q and add to res var v {{.ElementName}} for i := 0; i < 8; i++ { - t[i] %= q - v[0] = uint32(t[i]) + v[0] = uint32(t[i] % q) res.Add(&res, &v) } if n % blockSize != 0 { diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index 143df0bfc9..e9464dcb8b 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 8339322716531241176 +// We include the hash to force the Go compiler to recompile: 13481038126710208331 #include "../asm/element_31b_arm64.s" diff --git a/field/koalabear/vector_amd64.go b/field/koalabear/vector_amd64.go index cc80c68d04..359ed5eb7d 100644 --- a/field/koalabear/vector_amd64.go +++ b/field/koalabear/vector_amd64.go @@ -116,10 +116,9 @@ func (vector *Vector) Sum() (res Element) { var t [8]uint64 // stores the accumulators (not reduced mod q) sumVec(&t[0], &(*vector)[0], n/blockSize) // we reduce the accumulators mod q and add to res - var v Element for i := 0; i < 8; i++ { t[i] %= q - v[0] = uint32(t[i]) + v[0] = uint32(t[i] % q) res.Add(&res, &v) } if n%blockSize != 0 { @@ -153,8 +152,7 @@ func (vector *Vector) InnerProduct(other Vector) (res Element) { // we reduce the accumulators mod q and add to res var v Element for i := 0; i < 8; i++ { - t[i] %= q - v[0] = uint32(t[i]) + v[0] = uint32(t[i] % q) res.Add(&res, &v) } if n%blockSize != 0 { diff --git a/field/koalabear/vector_arm64.go b/field/koalabear/vector_arm64.go index b539cd9a7b..2501b365fa 100644 --- a/field/koalabear/vector_arm64.go +++ b/field/koalabear/vector_arm64.go @@ -24,6 +24,9 @@ func addVec(qLane *uint32, res, a, b *Element, n uint64) //go:noescape func subVec(qLane *uint32, res, a, b *Element, n uint64) +//go:noescape +func sumVec(t *uint64, a *Element, n uint64) + // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -72,7 +75,26 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { // Sum computes the sum of all elements in the vector. 
func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) + n := uint64(len(*vector)) + if n == 0 { + return + } + + const blockSize = 8 + var t [4]uint64 // stores the accumulators (not reduced mod q) + sumVec(&t[0], &(*vector)[0], n/blockSize) + // we reduce the accumulators mod q and add to res + var v Element + for i := 0; i < 4; i++ { + v[0] = uint32(t[i] % q) + res.Add(&res, &v) + } + if n%blockSize != 0 { + // call sumVecGeneric on the rest + start := n - n%blockSize + sumVecGeneric(&res, (*vector)[start:]) + } + return } From 5747fb264c683b5dab6bddefd8a44f23f6248092 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 11 Dec 2024 10:53:45 -0600 Subject: [PATCH 63/74] feat,perf: faster sum on neon f31 --- field/asm/element_31b_arm64.s | 37 ++++++++++++------- field/babybear/element_arm64.s | 2 +- field/babybear/vector_arm64.go | 2 +- field/generator/asm/arm64/element_vec_F31.go | 32 ++++++++++------ .../templates/element/vector_ops_asm.go | 2 +- field/koalabear/element_arm64.s | 2 +- field/koalabear/vector_arm64.go | 2 +- 7 files changed, 50 insertions(+), 29 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index 74b6c9b6f9..58fc093cc1 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -47,6 +47,7 @@ done4: // sumVec(t *uint64, a *[]uint32, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOFRAME|NOSPLIT, $0-24 + // zeroing accumulators VMOVQ $0, $0, V4 VMOVQ $0, $0, V5 VMOVQ $0, $0, V6 @@ -55,19 +56,29 @@ TEXT ·sumVec(SB), NOFRAME|NOSPLIT, $0-24 MOVD n+16(FP), R2 loop5: - CBZ R2, done6 - VLD2.P 16(R0), [V0.S2, V1.S2] - VLD2.P 16(R0), [V2.S2, V3.S2] - VUSHLL $0, V0.S2, V0.D2 // convert to 64 bits - VUSHLL $0, V1.S2, V1.D2 // convert to 64 bits - VADD V0.D2, V4.D2, V4.D2 // acc1 += a1 - VADD V1.D2, V5.D2, V5.D2 // acc2 += a2 - VUSHLL $0, V2.S2, V2.D2 // convert to 64 bits - VUSHLL $0, V3.S2, V3.D2 // convert to 64 bits - VADD V2.D2, V6.D2, V6.D2 // acc3 += a3 - VADD V3.D2, V7.D2, V7.D2 // acc4 += a4 - SUB $1, R2, R2 - JMP loop5 + CBZ R2, done6 + + // blockSize is 16 uint32; we load 4 vectors of 4 uint32 at a time + // (4*4)*4 = 64 bytes ~= 1 cache line + // since our values are 31 bits, we can add 2 by 2 these vectors + // we are left with 2 vectors of 4x32 bits values + // that we accumulate in 4*2*64bits accumulators + // the caller will reduce mod q the accumulators. 
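+	// note: each 64-bit accumulator lane grows by at most 2*(2^31-1) < 2^32
+	// per iteration, so roughly 2^32 iterations fit before it could overflow,
+	// far beyond any realistic vector length.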
+ + VLD2.P 32(R0), [V0.S4, V1.S4] + VADD V0.S4, V1.S4, V0.S4 // a1 += a2 + VLD2.P 32(R0), [V2.S4, V3.S4] + VADD V2.S4, V3.S4, V2.S4 // a3 += a4 + VUSHLL $0, V0.S2, V1.D2 // convert low words to 64 bits + VADD V1.D2, V5.D2, V5.D2 // acc2 += a2 + VUSHLL2 $0, V0.S4, V0.D2 // convert high words to 64 bits + VADD V0.D2, V4.D2, V4.D2 // acc1 += a1 + VUSHLL $0, V2.S2, V3.D2 // convert low words to 64 bits + VADD V3.D2, V7.D2, V7.D2 // acc4 += a4 + VUSHLL2 $0, V2.S4, V2.D2 // convert high words to 64 bits + VADD V2.D2, V6.D2, V6.D2 // acc3 += a3 + SUB $1, R2, R2 + JMP loop5 done6: VADD V4.D2, V6.D2, V4.D2 // acc1 += acc3 diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index e9464dcb8b..a48da93dc3 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 13481038126710208331 +// We include the hash to force the Go compiler to recompile: 12345450131470704682 #include "../asm/element_31b_arm64.s" diff --git a/field/babybear/vector_arm64.go b/field/babybear/vector_arm64.go index 88e175dcdd..766c634038 100644 --- a/field/babybear/vector_arm64.go +++ b/field/babybear/vector_arm64.go @@ -80,7 +80,7 @@ func (vector *Vector) Sum() (res Element) { return } - const blockSize = 8 + const blockSize = 16 var t [4]uint64 // stores the accumulators (not reduced mod q) sumVec(&t[0], &(*vector)[0], n/blockSize) // we reduce the accumulators mod q and add to res diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index 43b208ca0c..3f4f9da94e 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -123,7 +123,7 @@ func (f *FFArm64) generateSumVecF31() { acc3 := arm64.V6 acc4 := arm64.V7 - // zero out accumulators + f.Comment("zeroing accumulators") f.VMOVQ_cst(0, 0, acc1) f.VMOVQ_cst(0, 0, acc2) f.VMOVQ_cst(0, 0, acc3) @@ -143,23 +143,33 @@ func (f *FFArm64) generateSumVecF31() { f.MOVD("n+16(FP)", n) f.LABEL(loop) - f.CBZ(n, done) - const offset = 8 * 4 // we process 4 uint32 at a time + f.WriteLn(` + // blockSize is 16 uint32; we load 4 vectors of 4 uint32 at a time + // (4*4)*4 = 64 bytes ~= 1 cache line + // since our values are 31 bits, we can add 2 by 2 these vectors + // we are left with 2 vectors of 4x32 bits values + // that we accumulate in 4*2*64bits accumulators + // the caller will reduce mod q the accumulators. 
+ `) - f.VLD2_P(offset/2, aPtr, a1.S2(), a2.S2()) // load 2*2 uint32 - f.VLD2_P(offset/2, aPtr, a3.S2(), a4.S2()) // load 2*2 uint32 + const offset = 8 * 4 + f.VLD2_P(offset, aPtr, a1.S4(), a2.S4()) + f.VADD(a1.S4(), a2.S4(), a1.S4(), "a1 += a2") - f.VUSHLL(0, a1.S2(), a1.D2(), "convert to 64 bits") - f.VUSHLL(0, a2.S2(), a2.D2(), "convert to 64 bits") - f.VADD(a1.D2(), acc1, acc1, "acc1 += a1") + f.VLD2_P(offset, aPtr, a3.S4(), a4.S4()) + f.VADD(a3.S4(), a4.S4(), a3.S4(), "a3 += a4") + + f.VUSHLL(0, a1.S2(), a2.D2(), "convert low words to 64 bits") f.VADD(a2.D2(), acc2, acc2, "acc2 += a2") + f.VUSHLL2(0, a1.S4(), a1.D2(), "convert high words to 64 bits") + f.VADD(a1.D2(), acc1, acc1, "acc1 += a1") - f.VUSHLL(0, a3.S2(), a3.D2(), "convert to 64 bits") - f.VUSHLL(0, a4.S2(), a4.D2(), "convert to 64 bits") - f.VADD(a3.D2(), acc3, acc3, "acc3 += a3") + f.VUSHLL(0, a3.S2(), a4.D2(), "convert low words to 64 bits") f.VADD(a4.D2(), acc4, acc4, "acc4 += a4") + f.VUSHLL2(0, a3.S4(), a3.D2(), "convert high words to 64 bits") + f.VADD(a3.D2(), acc3, acc3, "acc3 += a3") // decrement n f.SUB(1, n, n) diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index a6c4c99a84..9ff3b30761 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -222,7 +222,7 @@ func (vector *Vector) Sum() (res {{.ElementName}}) { return } - const blockSize = 8 + const blockSize = 16 var t [4]uint64 // stores the accumulators (not reduced mod q) sumVec(&t[0], &(*vector)[0], n/blockSize) // we reduce the accumulators mod q and add to res diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index e9464dcb8b..a48da93dc3 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 13481038126710208331 +// We include the hash to force the Go compiler to recompile: 12345450131470704682 #include "../asm/element_31b_arm64.s" diff --git a/field/koalabear/vector_arm64.go b/field/koalabear/vector_arm64.go index 2501b365fa..083e47cd99 100644 --- a/field/koalabear/vector_arm64.go +++ b/field/koalabear/vector_arm64.go @@ -80,7 +80,7 @@ func (vector *Vector) Sum() (res Element) { return } - const blockSize = 8 + const blockSize = 16 var t [4]uint64 // stores the accumulators (not reduced mod q) sumVec(&t[0], &(*vector)[0], n/blockSize) // we reduce the accumulators mod q and add to res From e76768fc9611147956a255a1c00617ca3811cd22 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 11 Dec 2024 14:29:30 -0600 Subject: [PATCH 64/74] perf: move q from const in vector broadcast --- field/asm/element_31b_arm64.s | 22 +++++++++---------- field/babybear/element_arm64.s | 2 +- field/babybear/vector_arm64.go | 4 ++-- field/generator/asm/arm64/element_vec_F31.go | 17 +++++++------- .../templates/element/vector_ops_asm.go | 4 ++-- field/koalabear/element_arm64.s | 2 +- field/koalabear/vector_arm64.go | 4 ++-- 7 files changed, 28 insertions(+), 27 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index 58fc093cc1..1ce921a72e 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -3,22 +3,22 @@ #include "funcdata.h" #include "go_asm.h" -// addVec(qq *uint32, res, a, b *Element, n uint64) +// addVec(res, a, b *Element, n uint64) TEXT 
·addVec(SB), NOFRAME|NOSPLIT, $0-40 - LDP qq+0(FP), (R1, R0) - LDP a+16(FP), (R2, R3) - MOVD n+32(FP), R4 - VLD1 0(R1), [V2.S4] // broadcast q into V2.S4 + LDP res+0(FP), (R0, R1) + LDP b+16(FP), (R2, R3) + VMOVS $const_q, V3 + VDUP V3.S[0], V3.S4 // broadcast q into V3 loop1: - CBZ R4, done2 - VLD1.P 16(R2), [V0.S4] - VLD1.P 16(R3), [V1.S4] + CBZ R3, done2 + VLD1.P 16(R1), [V0.S4] + VLD1.P 16(R2), [V1.S4] VADD V0.S4, V1.S4, V1.S4 // b = a + b - VSUB V2.S4, V1.S4, V3.S4 // t = q - b - VUMIN V3.S4, V1.S4, V1.S4 // b = min(t, b) + VSUB V3.S4, V1.S4, V2.S4 // t = q - b + VUMIN V2.S4, V1.S4, V1.S4 // b = min(t, b) VST1.P [V1.S4], 16(R0) // res = b - SUB $1, R4, R4 + SUB $1, R3, R3 JMP loop1 done2: diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index a48da93dc3..746a2b753a 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 12345450131470704682 +// We include the hash to force the Go compiler to recompile: 6685189974678439935 #include "../asm/element_31b_arm64.s" diff --git a/field/babybear/vector_arm64.go b/field/babybear/vector_arm64.go index 766c634038..c70a24de06 100644 --- a/field/babybear/vector_arm64.go +++ b/field/babybear/vector_arm64.go @@ -19,7 +19,7 @@ func init() { } //go:noescape -func addVec(qLane *uint32, res, a, b *Element, n uint64) +func addVec(res, a, b *Element, n uint64) //go:noescape func subVec(qLane *uint32, res, a, b *Element, n uint64) @@ -39,7 +39,7 @@ func (vector *Vector) Add(a, b Vector) { } const blockSize = 4 - addVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize) + addVec(&(*vector)[0], &a[0], &b[0], n/blockSize) if n%blockSize != 0 { // call addVecGeneric on the rest start := n - n%blockSize diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index 3f4f9da94e..41ae8bd3bb 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -3,32 +3,33 @@ package arm64 import "github.com/consensys/bavard/arm64" func (f *FFArm64) generateAddVecF31() { - f.Comment("addVec(qq *uint32, res, a, b *Element, n uint64)") + f.Comment("addVec(res, a, b *Element, n uint64)") registers := f.FnHeader("addVec", 0, 40) defer f.AssertCleanStack(0, 0) // registers resPtr := registers.Pop() - qqPtr := registers.Pop() + // qqPtr := registers.Pop() aPtr := registers.Pop() bPtr := registers.Pop() n := registers.Pop() a := arm64.V0.S4() b := arm64.V1.S4() - q := arm64.V2.S4() - t := arm64.V3.S4() + t := arm64.V2.S4() + q := arm64.V3 // labels loop := f.NewLabel("loop") done := f.NewLabel("done") // load arguments - f.LDP("qq+0(FP)", qqPtr, resPtr) - f.LDP("a+16(FP)", aPtr, bPtr) - f.MOVD("n+32(FP)", n) + f.LDP("res+0(FP)", resPtr, aPtr) + f.LDP("b+16(FP)", bPtr, n) - f.VLD1(0, qqPtr, q, "broadcast q into "+string(q)) + f.VMOVS("$const_q", q) + f.VDUP(q.SAt(0), q.S4(), "broadcast q into "+string(q)) + q = q.S4() f.LABEL(loop) diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index 9ff3b30761..0f95084bcc 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -160,7 +160,7 @@ func init() { } //go:noescape -func addVec(qLane *uint32, res, a, b *{{.ElementName}}, n uint64) +func addVec(res, a, b *{{.ElementName}}, n 
uint64) //go:noescape func subVec(qLane *uint32, res, a, b *{{.ElementName}}, n uint64) @@ -181,7 +181,7 @@ func (vector *Vector) Add(a, b Vector) { } const blockSize = 4 - addVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize) + addVec(&(*vector)[0], &a[0], &b[0], n/blockSize) if n % blockSize != 0 { // call addVecGeneric on the rest start := n - n % blockSize diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index a48da93dc3..746a2b753a 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 12345450131470704682 +// We include the hash to force the Go compiler to recompile: 6685189974678439935 #include "../asm/element_31b_arm64.s" diff --git a/field/koalabear/vector_arm64.go b/field/koalabear/vector_arm64.go index 083e47cd99..8f0777d5e6 100644 --- a/field/koalabear/vector_arm64.go +++ b/field/koalabear/vector_arm64.go @@ -19,7 +19,7 @@ func init() { } //go:noescape -func addVec(qLane *uint32, res, a, b *Element, n uint64) +func addVec(res, a, b *Element, n uint64) //go:noescape func subVec(qLane *uint32, res, a, b *Element, n uint64) @@ -39,7 +39,7 @@ func (vector *Vector) Add(a, b Vector) { } const blockSize = 4 - addVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize) + addVec(&(*vector)[0], &a[0], &b[0], n/blockSize) if n%blockSize != 0 { // call addVecGeneric on the rest start := n - n%blockSize From 18390fdaa210d03f63d31af5c50dcccd17179b74 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 11 Dec 2024 15:06:28 -0600 Subject: [PATCH 65/74] style: cleaning stuff --- field/asm/element_31b_arm64.s | 68 +++++++++------ field/babybear/element_arm64.s | 2 +- field/babybear/vector_arm64.go | 15 +--- field/generator/asm/arm64/element_vec_F31.go | 86 ++++++++++--------- .../templates/element/vector_ops_asm.go | 16 +--- field/koalabear/element_arm64.s | 2 +- field/koalabear/vector_arm64.go | 15 +--- 7 files changed, 98 insertions(+), 106 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index 1ce921a72e..1ef1b5778c 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -4,45 +4,63 @@ #include "go_asm.h" // addVec(res, a, b *Element, n uint64) -TEXT ·addVec(SB), NOFRAME|NOSPLIT, $0-40 - LDP res+0(FP), (R0, R1) - LDP b+16(FP), (R2, R3) - VMOVS $const_q, V3 - VDUP V3.S[0], V3.S4 // broadcast q into V3 +TEXT ·addVec(SB), NOFRAME|NOSPLIT, $0-32 + LDP res+0(FP), (R0, R1) + LDP b+16(FP), (R2, R3) + +#define a V0 +#define b V1 +#define t V2 +#define q V3 + VMOVS $const_q, q + VDUP q.S[0], q.S4 // broadcast q into q loop1: CBZ R3, done2 - VLD1.P 16(R1), [V0.S4] - VLD1.P 16(R2), [V1.S4] - VADD V0.S4, V1.S4, V1.S4 // b = a + b - VSUB V3.S4, V1.S4, V2.S4 // t = q - b - VUMIN V2.S4, V1.S4, V1.S4 // b = min(t, b) - VST1.P [V1.S4], 16(R0) // res = b + VLD1.P 16(R1), [a.S4] + VLD1.P 16(R2), [b.S4] + VADD a.S4, b.S4, b.S4 // b = a + b + VSUB q.S4, b.S4, t.S4 // t = q - b + VUMIN t.S4, b.S4, b.S4 // b = min(t, b) + VST1.P [b.S4], 16(R0) // res = b SUB $1, R3, R3 JMP loop1 done2: +#undef a +#undef b +#undef t +#undef q RET -// subVec(qq *uint32, res, a, b *Element, n uint64) -TEXT ·subVec(SB), NOFRAME|NOSPLIT, $0-40 - LDP qLane+0(FP), (R1, R0) - LDP a+16(FP), (R2, R3) - MOVD n+32(FP), R4 - VLD1 0(R1), [V2.S4] // broadcast q into V2.S4 +// subVec(res, a, b *Element, n uint64) +TEXT ·subVec(SB), NOFRAME|NOSPLIT, $0-32 + LDP res+0(FP), (R0, R1) + LDP 
b+16(FP), (R2, R3) + +#define a V0 +#define b V1 +#define t V2 +#define q V3 + VMOVS $const_q, q + VDUP q.S[0], q.S4 // broadcast q into q loop3: - CBZ R4, done4 - VLD1.P 16(R2), [V0.S4] - VLD1.P 16(R3), [V1.S4] - VSUB V1.S4, V0.S4, V1.S4 // b = a - b - VADD V1.S4, V2.S4, V3.S4 // t = b + q - VUMIN V3.S4, V1.S4, V1.S4 // b = min(t, b) - VST1.P [V1.S4], 16(R0) // res = b - SUB $1, R4, R4 + CBZ R3, done4 + VLD1.P 16(R1), [a.S4] + VLD1.P 16(R2), [b.S4] + VSUB b.S4, a.S4, b.S4 // b = a - b + VADD b.S4, q.S4, t.S4 // t = b + q + VUMIN t.S4, b.S4, b.S4 // b = min(t, b) + VST1.P [b.S4], 16(R0) // res = b + SUB $1, R3, R3 JMP loop3 done4: +#undef a +#undef b +#undef q +#undef t RET // sumVec(t *uint64, a *[]uint32, n uint64) res = sum(a[0...n]) diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index 746a2b753a..feecabef0a 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 6685189974678439935 +// We include the hash to force the Go compiler to recompile: 11749081742385060317 #include "../asm/element_31b_arm64.s" diff --git a/field/babybear/vector_arm64.go b/field/babybear/vector_arm64.go index c70a24de06..ec387191ed 100644 --- a/field/babybear/vector_arm64.go +++ b/field/babybear/vector_arm64.go @@ -7,22 +7,11 @@ package babybear -// qLane is a vector with all elements set to q -// TODO figure out why the arm64 assembly to broadcast a scalar is not working -var qLane [4]uint32 - -func init() { - qLane[0] = q - qLane[1] = q - qLane[2] = q - qLane[3] = q -} - //go:noescape func addVec(res, a, b *Element, n uint64) //go:noescape -func subVec(qLane *uint32, res, a, b *Element, n uint64) +func subVec(res, a, b *Element, n uint64) //go:noescape func sumVec(t *uint64, a *Element, n uint64) @@ -59,7 +48,7 @@ func (vector *Vector) Sub(a, b Vector) { } const blockSize = 4 - subVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize) + subVec(&(*vector)[0], &a[0], &b[0], n/blockSize) if n%blockSize != 0 { // call subVecGeneric on the rest start := n - n%blockSize diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index 41ae8bd3bb..cd22c5a643 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -4,8 +4,9 @@ import "github.com/consensys/bavard/arm64" func (f *FFArm64) generateAddVecF31() { f.Comment("addVec(res, a, b *Element, n uint64)") - registers := f.FnHeader("addVec", 0, 40) + registers := f.FnHeader("addVec", 0, 32) defer f.AssertCleanStack(0, 0) + defer registers.AssertCleanState() // registers resPtr := registers.Pop() @@ -14,11 +15,6 @@ func (f *FFArm64) generateAddVecF31() { bPtr := registers.Pop() n := registers.Pop() - a := arm64.V0.S4() - b := arm64.V1.S4() - t := arm64.V2.S4() - q := arm64.V3 - // labels loop := f.NewLabel("loop") done := f.NewLabel("done") @@ -27,9 +23,13 @@ func (f *FFArm64) generateAddVecF31() { f.LDP("res+0(FP)", resPtr, aPtr) f.LDP("b+16(FP)", bPtr, n) + a := registers.PopV("a") + b := registers.PopV("b") + t := registers.PopV("t") + q := registers.PopV("q") + f.VMOVS("$const_q", q) f.VDUP(q.SAt(0), q.S4(), "broadcast q into "+string(q)) - q = q.S4() f.LABEL(loop) @@ -37,50 +37,54 @@ func (f *FFArm64) generateAddVecF31() { const offset = 4 * 4 // we process 4 uint32 at a time - f.VLD1_P(offset, aPtr, a) - f.VLD1_P(offset, bPtr, b) + f.VLD1_P(offset, aPtr, a.S4()) + 
f.VLD1_P(offset, bPtr, b.S4()) - f.VADD(a, b, b, "b = a + b") - f.VSUB(q, b, t, "t = q - b") - f.VUMIN(t, b, b, "b = min(t, b)") - f.VST1_P(b, resPtr, offset, "res = b") + f.VADD(a.S4(), b.S4(), b.S4(), "b = a + b") + f.VSUB(q.S4(), b.S4(), t.S4(), "t = q - b") + f.VUMIN(t.S4(), b.S4(), b.S4(), "b = min(t, b)") + f.VST1_P(b.S4(), resPtr, offset, "res = b") // decrement n f.SUB(1, n, n) f.JMP(loop) f.LABEL(done) + + registers.Push(resPtr, aPtr, bPtr, n) + registers.PushV(a, b, t, q) + f.RET() } func (f *FFArm64) generateSubVecF31() { - f.Comment("subVec(qq *uint32, res, a, b *Element, n uint64)") - registers := f.FnHeader("subVec", 0, 40) + f.Comment("subVec(res, a, b *Element, n uint64)") + registers := f.FnHeader("subVec", 0, 32) defer f.AssertCleanStack(0, 0) + defer registers.AssertCleanState() // registers resPtr := registers.Pop() - qPtr := registers.Pop() aPtr := registers.Pop() bPtr := registers.Pop() n := registers.Pop() - a := arm64.V0.S4() - b := arm64.V1.S4() - q := arm64.V2.S4() - t := arm64.V3.S4() - // labels loop := f.NewLabel("loop") done := f.NewLabel("done") // load arguments - f.LDP("qLane+0(FP)", qPtr, resPtr) - f.LDP("a+16(FP)", aPtr, bPtr) - f.MOVD("n+32(FP)", n) + f.LDP("res+0(FP)", resPtr, aPtr) + f.LDP("b+16(FP)", bPtr, n) + + a := registers.PopV("a") + b := registers.PopV("b") + t := registers.PopV("t") + q := registers.PopV("q") - f.VLD1(0, qPtr, q, "broadcast q into "+string(q)) + f.VMOVS("$const_q", q) + f.VDUP(q.SAt(0), q.S4(), "broadcast q into "+string(q)) f.LABEL(loop) @@ -88,19 +92,23 @@ func (f *FFArm64) generateSubVecF31() { const offset = 4 * 4 // we process 4 uint32 at a time - f.VLD1_P(offset, aPtr, a) - f.VLD1_P(offset, bPtr, b) + f.VLD1_P(offset, aPtr, a.S4()) + f.VLD1_P(offset, bPtr, b.S4()) - f.VSUB(b, a, b, "b = a - b") - f.VADD(b, q, t, "t = b + q") - f.VUMIN(t, b, b, "b = min(t, b)") - f.VST1_P(b, resPtr, offset, "res = b") + f.VSUB(b.S4(), a.S4(), b.S4(), "b = a - b") + f.VADD(b.S4(), q.S4(), t.S4(), "t = b + q") + f.VUMIN(t.S4(), b.S4(), b.S4(), "b = min(t, b)") + f.VST1_P(b.S4(), resPtr, offset, "res = b") // decrement n f.SUB(1, n, n) f.JMP(loop) f.LABEL(done) + + registers.Push(resPtr, aPtr, bPtr, n) + registers.PushV(a, b, q, t) + f.RET() } @@ -115,14 +123,14 @@ func (f *FFArm64) generateSumVecF31() { tPtr := registers.Pop() n := registers.Pop() - a1 := arm64.V0 - a2 := arm64.V1 - a3 := arm64.V2 - a4 := arm64.V3 - acc1 := arm64.V4 - acc2 := arm64.V5 - acc3 := arm64.V6 - acc4 := arm64.V7 + a1 := registers.PopV() + a2 := registers.PopV() + a3 := registers.PopV() + a4 := registers.PopV() + acc1 := registers.PopV() + acc2 := registers.PopV() + acc3 := registers.PopV() + acc4 := registers.PopV() f.Comment("zeroing accumulators") f.VMOVQ_cst(0, 0, acc1) diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index 0f95084bcc..9fc0381f63 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -148,27 +148,15 @@ const VectorOpsArm64 = VectorOpsPureGo const VectorOpsArm64F31 = ` - -// qLane is a vector with all elements set to q -// TODO figure out why the arm64 assembly to broadcast a scalar is not working -var qLane [4]uint32 -func init() { - qLane[0] = q - qLane[1] = q - qLane[2] = q - qLane[3] = q -} - //go:noescape func addVec(res, a, b *{{.ElementName}}, n uint64) //go:noescape -func subVec(qLane *uint32, res, a, b *{{.ElementName}}, n uint64) +func subVec(res, a, b 
*{{.ElementName}}, n uint64) //go:noescape func sumVec(t *uint64, a *{{.ElementName}}, n uint64) - // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -201,7 +189,7 @@ func (vector *Vector) Sub(a, b Vector) { } const blockSize = 4 - subVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize) + subVec(&(*vector)[0], &a[0], &b[0], n/blockSize) if n % blockSize != 0 { // call subVecGeneric on the rest start := n - n % blockSize diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index 746a2b753a..feecabef0a 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 6685189974678439935 +// We include the hash to force the Go compiler to recompile: 11749081742385060317 #include "../asm/element_31b_arm64.s" diff --git a/field/koalabear/vector_arm64.go b/field/koalabear/vector_arm64.go index 8f0777d5e6..f322e7043b 100644 --- a/field/koalabear/vector_arm64.go +++ b/field/koalabear/vector_arm64.go @@ -7,22 +7,11 @@ package koalabear -// qLane is a vector with all elements set to q -// TODO figure out why the arm64 assembly to broadcast a scalar is not working -var qLane [4]uint32 - -func init() { - qLane[0] = q - qLane[1] = q - qLane[2] = q - qLane[3] = q -} - //go:noescape func addVec(res, a, b *Element, n uint64) //go:noescape -func subVec(qLane *uint32, res, a, b *Element, n uint64) +func subVec(res, a, b *Element, n uint64) //go:noescape func sumVec(t *uint64, a *Element, n uint64) @@ -59,7 +48,7 @@ func (vector *Vector) Sub(a, b Vector) { } const blockSize = 4 - subVec(&qLane[0], &(*vector)[0], &a[0], &b[0], n/blockSize) + subVec(&(*vector)[0], &a[0], &b[0], n/blockSize) if n%blockSize != 0 { // call subVecGeneric on the rest start := n - n%blockSize From 5dcb7446970b07305567c96b41481461e03f3287 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 13 Dec 2024 12:54:12 -0600 Subject: [PATCH 66/74] checkpoint --- field/generator/asm/arm64/element_vec_F31.go | 67 +++++++++++++++++++ .../templates/element/vector_ops_asm.go | 19 +++++- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index cd22c5a643..0a21cce239 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -194,3 +194,70 @@ func (f *FFArm64) generateSumVecF31() { f.RET() } + +func (f *FFArm64) generateMulVecF31() { + f.Comment("mulVec(res, a, b *Element, n uint64)") + registers := f.FnHeader("mulVec", 0, 32) + defer f.AssertCleanStack(0, 0) + defer registers.AssertCleanState() + + // registers + resPtr := registers.Pop() + aPtr := registers.Pop() + bPtr := registers.Pop() + n := registers.Pop() + + // labels + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.LDP("res+0(FP)", resPtr, aPtr) + f.LDP("b+16(FP)", bPtr, n) + + a := registers.PopV("a") + b := registers.PopV("b") + t := registers.PopV("t") + q := registers.PopV("q") + qInvNeg := registers.PopV("qInvNeg") + p1 := registers.PopV("p1") + + f.VMOVS("$const_q", q) + f.VDUP(q.SAt(0), q.S4(), "broadcast q into "+string(q)) + + f.VMOVS("$const_qInvNeg", qInvNeg) + f.VDUP(qInvNeg.SAt(0), qInvNeg.S4(), "broadcast qInvNeg into "+string(qInvNeg)) + + f.LABEL(loop) + + f.CBZ(n, done) + + const offset = 
4 * 4 // we process 4 uint32 at a time
+
+	f.VLD1_P(offset, aPtr, a.S4())
+	f.VLD1_P(offset, bPtr, b.S4())
+
+	// let's compute p1 := a1 * b1
+	f.VPMULL(a.S4(), b.S4(), p1.D2())
+	// let's move the low words in t
+	f.VMOV(p1.D2(), t.D2())
+
+	f.VUSHLL2(0, a.S4(), a.D2(), "convert high words to 64 bits")
+	f.VUSHLL2(0, b.S4(), b.D2(), "convert high words to 64 bits")
+
+	f.VMUL(a.S4(), b.S4(), b.S4(), "b = a * b")
+	f.VSUB(q.S4(), b.S4(), t.S4(), "t = q - b")
+	f.VUMIN(t.S4(), b.S4(), b.S4(), "b = min(t, b)")
+	f.VST1_P(b.S4(), resPtr, offset, "res = b")
+
+	// decrement n
+	f.SUB(1, n, n)
+	f.JMP(loop)
+
+	f.LABEL(done)
+
+	registers.Push(resPtr, aPtr, bPtr, n)
+	registers.PushV(a, b, t, q) //, a1, b1)
+
+	f.RET()
+}
diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go
index 9fc0381f63..19be822c59 100644
--- a/field/generator/internal/templates/element/vector_ops_asm.go
+++ b/field/generator/internal/templates/element/vector_ops_asm.go
@@ -157,6 +157,9 @@ func subVec(res, a, b *{{.ElementName}}, n uint64)
 //go:noescape
 func sumVec(t *uint64, a *{{.ElementName}}, n uint64)
 
+//go:noescape
+func mulVec(res, a, b *{{.ElementName}}, n uint64)
+
 // Add adds two vectors element-wise and stores the result in self.
 // It panics if the vectors don't have the same length.
 func (vector *Vector) Add(a, b Vector) {
@@ -238,7 +241,21 @@ func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) {
 // Mul multiplies two vectors element-wise and stores the result in self.
 // It panics if the vectors don't have the same length.
 func (vector *Vector) Mul(a, b Vector) {
-	mulVecGeneric(*vector, a, b)
+	if len(a) != len(b) || len(a) != len(*vector) {
+		panic("vector.Mul: vectors don't have the same length")
+	}
+	n := uint64(len(a))
+	if n == 0 {
+		return
+	}
+
+	const blockSize = 4
+	mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize)
+	if n % blockSize != 0 {
+		// call mulVecGeneric on the rest
+		start := n - n % blockSize
+		mulVecGeneric((*vector)[start:], a[start:], b[start:])
+	}
 }
`

From 69dd02d2af7761d2d261cc2ac395d6e1a85c6fc7 Mon Sep 17 00:00:00 2001
From: Gautam Botrel
Date: Fri, 13 Dec 2024 16:57:15 -0600
Subject: [PATCH 67/74] test: fix integration test

---
 field/generator/asm/amd64/build.go     | 1 +
 field/generator/config/field_config.go | 2 +-
 field/generator/generator.go           | 6 +++++-
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go
index 9ee9f1a4bc..3cce93670d 100644
--- a/field/generator/asm/amd64/build.go
+++ b/field/generator/asm/amd64/build.go
@@ -221,6 +221,7 @@ func GenerateCommonASM(w io.Writer, nbWords, nbBits int, hasVector bool) error {
 	if nbBits == 31 {
 		return GenerateF31ASM(f, hasVector)
 	} else {
+		fmt.Printf("nbWords: %d, nbBits: %d\n", nbWords, nbBits)
 		panic("not implemented")
 	}
 
diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go
index e4e35567bd..c34ec305c2 100644
--- a/field/generator/config/field_config.go
+++ b/field/generator/config/field_config.go
@@ -103,7 +103,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) 
 	}
 	// pre compute field constants
 	F.NbBits = bModulus.BitLen()
-	F.F31 = F.NbBits <= 31
+	F.F31 = F.NbBits == 31
 
 	F.NbWords = len(bModulus.Bits())
 	F.NbWordsLastIndex = F.NbWords - 1
diff --git a/field/generator/generator.go b/field/generator/generator.go
index 4feac2c76b..29943ef6e4 100644
--- a/field/generator/generator.go
+++ 
b/field/generator/generator.go @@ -35,7 +35,11 @@ func GenerateFF(F *config.Field, outputDir string, options ...Option) error { } // generate field - if err := generateField(F, outputDir, cfg.asmConfig.IncludeDir, hashArm64, hashAMD64); err != nil { + asmIncludeDir := "" + if cfg.HasArm64() || cfg.HasAMD64() { + asmIncludeDir = cfg.asmConfig.IncludeDir + } + if err := generateField(F, outputDir, asmIncludeDir, hashArm64, hashAMD64); err != nil { return err } From 26d7a9d054e300738522f0efa6c3fc45486e5fd2 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 16 Dec 2024 09:31:34 -0600 Subject: [PATCH 68/74] checkpoint --- field/asm/element_31b_arm64.s | 24 ++++++++++++++++++++ field/babybear/element_arm64.s | 2 +- field/generator/asm/arm64/build.go | 2 +- field/generator/asm/arm64/element_vec_F31.go | 14 +++++------- field/koalabear/element_arm64.s | 2 +- 5 files changed, 33 insertions(+), 11 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index 1ef1b5778c..e7d64b009b 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -103,3 +103,27 @@ done6: VADD V5.D2, V7.D2, V5.D2 // acc2 += acc4 VST2.P [V4.D2, V5.D2], 0(R1) // store acc1 and acc2 RET + +// mulVec(res, a, b *Element, n uint64) +TEXT ·mulVec(SB), NOFRAME|NOSPLIT, $0-32 + LDP res+0(FP), (R0, R1) + LDP b+16(FP), (R2, R3) + VMOVS $const_q, V3 + VDUP V3.S[0], V3.S4 // broadcast q into V3 + VMOVS $const_qInvNeg, V4 + VDUP V4.S[0], V4.S4 // broadcast qInvNeg into V4 + +loop7: + CBZ R3, done8 + VLD1.P 16(R1), [V0.S4] + VLD1.P 16(R2), [V1.S4] + VUSHLL2 $0, V0.S4, V0.D2 // convert high words to 64 bits + VUSHLL2 $0, V1.S4, V1.D2 // convert high words to 64 bits + VSUB V3.S4, V1.S4, V2.S4 // t = q - b + VUMIN V2.S4, V1.S4, V1.S4 // b = min(t, b) + VST1.P [V1.S4], 16(R0) // res = b + SUB $1, R3, R3 + JMP loop7 + +done8: + RET diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index feecabef0a..a1a38a0e7d 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 11749081742385060317 +// We include the hash to force the Go compiler to recompile: 6730327240678615452 #include "../asm/element_31b_arm64.s" diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 0c7fdd77bb..2db1a857e6 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -234,7 +234,7 @@ func GenerateF31ASM(f *FFArm64, hasVector bool) error { f.generateAddVecF31() f.generateSubVecF31() f.generateSumVecF31() - // f.generateMulVecF31() + f.generateMulVecF31() // f.generateScalarMulVecF31() // f.generateInnerProdVecF31() diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index 355a7160ce..a90ff3a33c 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -199,7 +199,6 @@ func (f *FFArm64) generateMulVecF31() { f.Comment("mulVec(res, a, b *Element, n uint64)") registers := f.FnHeader("mulVec", 0, 32) defer f.AssertCleanStack(0, 0) - defer registers.AssertCleanState() // registers resPtr := registers.Pop() @@ -215,12 +214,11 @@ func (f *FFArm64) generateMulVecF31() { f.LDP("res+0(FP)", resPtr, aPtr) f.LDP("b+16(FP)", bPtr, n) - a := registers.PopV("a") - b := registers.PopV("b") - t := registers.PopV("t") - q := registers.PopV("q") - qInvNeg := 
registers.PopV("qInvNeg") - p1 := registers.PopV("p1") + a := registers.PopV() + b := registers.PopV() + t := registers.PopV() + q := registers.PopV() + qInvNeg := registers.PopV() f.VMOVS("$const_q", q) f.VDUP(q.SAt(0), q.S4(), "broadcast q into "+string(q)) @@ -238,7 +236,7 @@ func (f *FFArm64) generateMulVecF31() { f.VLD1_P(offset, bPtr, b.S4()) // let's compute p1 := a1 * b1 - f.VPMULL(a.S4(), b.S4(), p1.D2()) + // f.VPMULL(a.S4(), b.S4(), p1.D2()) // let's move the low words in t // f.VMOV(p1.D2(), t.D2()) diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index feecabef0a..a1a38a0e7d 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 11749081742385060317 +// We include the hash to force the Go compiler to recompile: 6730327240678615452 #include "../asm/element_31b_arm64.s" From 4f6dde1a0e464805c7900e7b37b279e069f37e0a Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 19 Dec 2024 10:16:54 -0600 Subject: [PATCH 69/74] checkpoint --- field/asm/element_31b_arm64.s | 62 +++++-- field/babybear/element_arm64.s | 2 +- field/generator/asm/arm64/element_vec_F31.go | 180 ++++++++++++++++--- field/koalabear/element_arm64.s | 2 +- 4 files changed, 208 insertions(+), 38 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index e7d64b009b..9a1397397e 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -20,7 +20,7 @@ loop1: VLD1.P 16(R1), [a.S4] VLD1.P 16(R2), [b.S4] VADD a.S4, b.S4, b.S4 // b = a + b - VSUB q.S4, b.S4, t.S4 // t = q - b + VSUB q.S4, b.S4, t.S4 // t = b - q VUMIN t.S4, b.S4, b.S4 // b = min(t, b) VST1.P [b.S4], 16(R0) // res = b SUB $1, R3, R3 @@ -108,20 +108,58 @@ done6: TEXT ·mulVec(SB), NOFRAME|NOSPLIT, $0-32 LDP res+0(FP), (R0, R1) LDP b+16(FP), (R2, R3) - VMOVS $const_q, V3 - VDUP V3.S[0], V3.S4 // broadcast q into V3 - VMOVS $const_qInvNeg, V4 - VDUP V4.S[0], V4.S4 // broadcast qInvNeg into V4 + VMOVS $const_q, V0 + VDUP V0.D[0], V0.D2 // broadcast q into V0 + VMOVQ $0xffffffff, $0xffffffff, V1 loop7: CBZ R3, done8 - VLD1.P 16(R1), [V0.S4] - VLD1.P 16(R2), [V1.S4] - VUSHLL2 $0, V0.S4, V0.D2 // convert high words to 64 bits - VUSHLL2 $0, V1.S4, V1.D2 // convert high words to 64 bits - VSUB V3.S4, V1.S4, V2.S4 // t = q - b - VUMIN V2.S4, V1.S4, V1.S4 // b = min(t, b) - VST1.P [V1.S4], 16(R0) // res = b + MOVWU.P 4(R1), R4 + MOVWU.P 4(R1), R5 + MOVWU.P 4(R2), R6 + MOVWU.P 4(R2), R7 + MUL R4, R6, R8 + MUL R5, R7, R9 + VMOV R8, V2.D[0] + VMOV R9, V2.D[1] + VSHL $0x1f, V2.D2, V4.D2 + VSHL $0x18, V2.D2, V5.D2 + MOVWU.P 4(R1), R10 + MOVWU.P 4(R1), R11 + VSUB V5.D2, V4.D2, V4.D2 + VSUB V2.D2, V4.D2, V3.D2 + MOVWU.P 4(R2), R12 + MOVWU.P 4(R2), R13 + VAND V3.B16, V1.B16, V3.B16 + VSHL $0x1f, V3.D2, V4.D2 + VSHL $0x18, V3.D2, V5.D2 + VSUB V5.D2, V4.D2, V4.D2 + VADD V3.D2, V4.D2, V3.D2 + VADD V3.D2, V2.D2, V3.D2 + VUSHR $0x20, V3.D2, V3.D2 + VSUB V0.D2, V3.D2, V4.D2 // t = q - m + VUMIN V4.S4, V3.S4, V3.S4 // m = min(t, m) + VSHL $0x20, V3.D2, V3.D2 + VREV64 V3.S2, V3.S2 + MUL R10, R12, R14 + MUL R11, R13, R15 + VMOV R14, V6.D[0] + VMOV R15, V6.D[1] + VSHL $0x1f, V6.D2, V8.D2 + VSHL $0x18, V6.D2, V9.D2 + VSUB V9.D2, V8.D2, V8.D2 + VSUB V6.D2, V8.D2, V7.D2 + VAND V7.B16, V1.B16, V7.B16 + VSHL $0x1f, V7.D2, V8.D2 + VSHL $0x18, V7.D2, V9.D2 + VSUB V9.D2, V8.D2, V8.D2 + VADD V7.D2, V8.D2, V7.D2 + VADD V7.D2, V6.D2, V7.D2 + VUSHR 
$0x20, V7.D2, V7.D2 + VSUB V0.D2, V7.D2, V8.D2 // t = q - m + VUMIN V8.S4, V7.S4, V7.S4 // m = min(t, m) + VADD V7.S4, V3.S4, V3.S4 + VST1.P [V3.S4], 16(R0) // res = b SUB $1, R3, R3 JMP loop7 diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index a1a38a0e7d..7fabf7a26a 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 6730327240678615452 +// We include the hash to force the Go compiler to recompile: 7281632604730491830 #include "../asm/element_31b_arm64.s" diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index a90ff3a33c..2e58d17acc 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -41,7 +41,7 @@ func (f *FFArm64) generateAddVecF31() { f.VLD1_P(offset, bPtr, b.S4()) f.VADD(a.S4(), b.S4(), b.S4(), "b = a + b") - f.VSUB(q.S4(), b.S4(), t.S4(), "t = q - b") + f.VSUB(q.S4(), b.S4(), t.S4(), "t = b - q") f.VUMIN(t.S4(), b.S4(), b.S4(), "b = min(t, b)") f.VST1_P(b.S4(), resPtr, offset, "res = b") @@ -214,39 +214,174 @@ func (f *FFArm64) generateMulVecF31() { f.LDP("res+0(FP)", resPtr, aPtr) f.LDP("b+16(FP)", bPtr, n) - a := registers.PopV() - b := registers.PopV() - t := registers.PopV() + // a := registers.PopV() + // b := registers.PopV() + // t := registers.PopV() + // q := registers.PopV() + // qInvNeg := registers.PopV() + + // f.VMOVS("$const_q", q) + // f.VDUP(q.SAt(0), q.S4(), "broadcast q into "+string(q)) + + // f.VMOVS("$const_qInvNeg", qInvNeg) + // f.VDUP(qInvNeg.SAt(0), qInvNeg.S4(), "broadcast qInvNeg into "+string(qInvNeg)) + q := registers.PopV() - qInvNeg := registers.PopV() f.VMOVS("$const_q", q) - f.VDUP(q.SAt(0), q.S4(), "broadcast q into "+string(q)) + f.VDUP(q.DAt(0), q.D2(), "broadcast q into "+string(q)) - f.VMOVS("$const_qInvNeg", qInvNeg) - f.VDUP(qInvNeg.SAt(0), qInvNeg.S4(), "broadcast qInvNeg into "+string(qInvNeg)) + const maxUint32 = 0xFFFFFFFF + mask := registers.PopV() + f.VMOVQ_cst(maxUint32, maxUint32, mask) f.LABEL(loop) f.CBZ(n, done) - const offset = 4 * 4 // we process 4 uint32 at a time + a0 := registers.Pop() + a1 := registers.Pop() + b0 := registers.Pop() + b1 := registers.Pop() + r0 := registers.Pop() + r1 := registers.Pop() + + v := registers.PopV() + m := registers.PopV() + t1 := registers.PopV() + t2 := registers.PopV() + + a0_2 := registers.Pop() + a1_2 := registers.Pop() + b0_2 := registers.Pop() + b1_2 := registers.Pop() + r0_2 := registers.Pop() + r1_2 := registers.Pop() + v_2 := registers.PopV() + m_2 := registers.PopV() + t1_2 := registers.PopV() + t2_2 := registers.PopV() + + // let's do 2 by 2 to start with; + f.MOVWUP_Load(4, aPtr, a0) + f.MOVWUP_Load(4, aPtr, a1) + f.MOVWUP_Load(4, bPtr, b0) + f.MOVWUP_Load(4, bPtr, b1) + + f.MUL(a0, b0, r0) + f.MUL(a1, b1, r1) + + f.VMOV(r0, v.DAt(0)) + f.VMOV(r1, v.DAt(1)) + + // qInvNeg == 2**31 - 2**24 -1 + // so we shift left by 31, store in a vector + // we shift left by 24, store in a vector + // we subtract the two vectors + f.VSHL(31, v.D2(), t1.D2()) + f.VSHL(24, v.D2(), t2.D2()) + f.MOVWUP_Load(4, aPtr, a0_2) + f.MOVWUP_Load(4, aPtr, a1_2) + + f.VSUB(t2.D2(), t1.D2(), t1.D2()) + f.VSUB(v.D2(), t1.D2(), m.D2()) + f.MOVWUP_Load(4, bPtr, b0_2) + f.MOVWUP_Load(4, bPtr, b1_2) + + // here we just want to keep m=low bits(vRes) + f.VAND(m.B16(), mask.B16(), m.B16()) + + // q == 2**31 - 2**24 + 1 
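+	// i.e. m*q = (m<<31) - (m<<24) + m, so the multiply by q can be built
+	// from shifts and adds only (the Go assembler exposes no 64x64-bit
+	// NEON lane multiply). note: the shift amounts are tied to this q
+	// (koalabear); babybear's q = 2**31 - 2**27 + 1 would need 27 here.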
+ f.VSHL(31, m.D2(), t1.D2()) + f.VSHL(24, m.D2(), t2.D2()) + f.VSUB(t2.D2(), t1.D2(), t1.D2()) + f.VADD(m.D2(), t1.D2(), m.D2()) + + f.VADD(m.D2(), v.D2(), m.D2()) + f.VUSHR(32, m.D2(), m.D2()) + + // now we do mod q if needed + f.VSUB(q.D2(), m.D2(), t1.D2(), "t = q - m") + f.VUMIN(t1.S4(), m.S4(), m.S4(), "m = min(t, m)") + + f.VSHL(32, m.D2(), m.D2()) + + // f.VMOV(m.DAt(0), r0) + // f.VMOV(m.DAt(1), r1) + + // f.MOVWUP_Store(r0, resPtr, 4) + // f.MOVWUP_Store(r1, resPtr, 4) + + f.MUL(a0_2, b0_2, r0_2) + f.MUL(a1_2, b1_2, r1_2) + + f.VMOV(r0_2, v_2.DAt(0)) + f.VMOV(r1_2, v_2.DAt(1)) + + // qInvNeg == 2**31 - 2**24 -1 + // so we shift left by 31, store in a vector + // we shift left by 24, store in a vector + // we subtract the two vectors + f.VSHL(31, v_2.D2(), t1_2.D2()) + f.VSHL(24, v_2.D2(), t2_2.D2()) + f.VSUB(t2_2.D2(), t1_2.D2(), t1_2.D2()) + f.VSUB(v_2.D2(), t1_2.D2(), m_2.D2()) + + // here we just want to keep m=low bits(vRes) + f.VAND(m_2.B16(), mask.B16(), m_2.B16()) + + // q == 2**31 - 2**24 + 1 + f.VSHL(31, m_2.D2(), t1_2.D2()) + f.VSHL(24, m_2.D2(), t2_2.D2()) + f.VSUB(t2_2.D2(), t1_2.D2(), t1_2.D2()) + f.VADD(m_2.D2(), t1_2.D2(), m_2.D2()) + + f.VADD(m_2.D2(), v_2.D2(), m_2.D2()) + f.VUSHR(32, m_2.D2(), m_2.D2()) + + // now we do mod q if needed + f.VSUB(q.D2(), m_2.D2(), t1_2.D2(), "t = q - m") + f.VUMIN(t1_2.S4(), m_2.S4(), m_2.S4(), "m = min(t, m)") + + f.VADD(m_2.S4(), m.S4(), m.S4()) + // f.VREV64(m.B16(), m.B16()) + + f.VST1_P(m.S4(), resPtr, 4*4, "res = b") + + // f.VMOV(m_2.DAt(0), r0_2) + // f.VMOV(m_2.DAt(1), r1_2) + + // f.MOVWUP_Store(r0_2, resPtr, 4) + // f.MOVWUP_Store(r1_2, resPtr, 4) + + // func montReduce(v uint64) uint32 { + // m := uint32(v) * qInvNeg + // t := uint32((v + uint64(m)*q) >> 32) + // if t >= q { + // t -= q + // } + // return t + // } - f.VLD1_P(offset, aPtr, a.S4()) - f.VLD1_P(offset, bPtr, b.S4()) + // g.VST1_P(vRes.D2(), resPtr, 0) - // let's compute p1 := a1 * b1 - // f.VPMULL(a.S4(), b.S4(), p1.D2()) - // let's move the low words in t - // f.VMOV(p1.D2(), t.D2()) + // const offset = 4 * 4 // we process 4 uint32 at a time - f.VUSHLL2(0, a.S4(), a.D2(), "convert high words to 64 bits") - f.VUSHLL2(0, b.S4(), b.D2(), "convert high words to 64 bits") + // f.VLD1_P(offset, aPtr, a.S4()) + // f.VLD1_P(offset, bPtr, b.S4()) - // f.VMUL(a.S4(), b.S4(), b.S4(), "b = a * b") - f.VSUB(q.S4(), b.S4(), t.S4(), "t = q - b") - f.VUMIN(t.S4(), b.S4(), b.S4(), "b = min(t, b)") - f.VST1_P(b.S4(), resPtr, offset, "res = b") + // // let's compute p1 := a1 * b1 + // // f.VPMULL(a.S4(), b.S4(), p1.D2()) + // // let's move the low words in t + // // f.VMOV(p1.D2(), t.D2()) + + // f.VUSHLL2(0, a.S4(), a.D2(), "convert high words to 64 bits") + // f.VUSHLL2(0, b.S4(), b.D2(), "convert high words to 64 bits") + + // // f.VMUL(a.S4(), b.S4(), b.S4(), "b = a * b") + // f.VSUB(q.S4(), b.S4(), t.S4(), "t = q - b") + // f.VUMIN(t.S4(), b.S4(), b.S4(), "b = min(t, b)") + // f.VST1_P(b.S4(), resPtr, offset, "res = b") // decrement n f.SUB(1, n, n) @@ -254,8 +389,5 @@ func (f *FFArm64) generateMulVecF31() { f.LABEL(done) - registers.Push(resPtr, aPtr, bPtr, n) - registers.PushV(a, b, t, q) //, a1, b1) - f.RET() } diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index a1a38a0e7d..7fabf7a26a 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 6730327240678615452 +// We include 
the hash to force the Go compiler to recompile: 7281632604730491830 #include "../asm/element_31b_arm64.s" From 50e000b2ee3da7b70d74f9c0ddd4f47ccd461283 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 19 Dec 2024 11:55:14 -0600 Subject: [PATCH 70/74] feat: remove neon mul, not enough support in golang --- field/asm/element_31b_arm64.s | 61 ------ field/babybear/element_arm64.s | 2 +- field/babybear/vector_arm64.go | 19 +- field/generator/asm/arm64/build.go | 2 +- field/generator/asm/arm64/element_vec_F31.go | 197 ------------------ .../templates/element/vector_ops_asm.go | 19 +- field/koalabear/element_arm64.s | 2 +- field/koalabear/vector_arm64.go | 19 +- go.mod | 4 +- go.sum | 2 + 10 files changed, 9 insertions(+), 318 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index cffe20c34a..87ca1c41f1 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -103,64 +103,3 @@ done6: VADD V5.D2, V7.D2, V5.D2 // acc2 += acc4 VST2.P [V4.D2, V5.D2], 0(R1) // store acc1 and acc2 RET - -// mulVec(res, a, b *Element, n uint64) -TEXT ·mulVec(SB), NOFRAME|NOSPLIT, $0-32 - LDP res+0(FP), (R0, R1) - LDP b+16(FP), (R2, R3) - VMOVS $const_q, V0 - VDUP V0.D[0], V0.D2 // broadcast q into V0 - VMOVQ $0xffffffff, $0xffffffff, V1 - -loop7: - CBZ R3, done8 - MOVWU.P 4(R1), R4 - MOVWU.P 4(R1), R5 - MOVWU.P 4(R2), R6 - MOVWU.P 4(R2), R7 - MUL R4, R6, R8 - MUL R5, R7, R9 - VMOV R8, V2.D[0] - VMOV R9, V2.D[1] - VSHL $0x1f, V2.D2, V4.D2 - VSHL $0x18, V2.D2, V5.D2 - MOVWU.P 4(R1), R10 - MOVWU.P 4(R1), R11 - VSUB V5.D2, V4.D2, V4.D2 - VSUB V2.D2, V4.D2, V3.D2 - MOVWU.P 4(R2), R12 - MOVWU.P 4(R2), R13 - VAND V3.B16, V1.B16, V3.B16 - VSHL $0x1f, V3.D2, V4.D2 - VSHL $0x18, V3.D2, V5.D2 - VSUB V5.D2, V4.D2, V4.D2 - VADD V3.D2, V4.D2, V3.D2 - VADD V3.D2, V2.D2, V3.D2 - VUSHR $0x20, V3.D2, V3.D2 - VSUB V0.D2, V3.D2, V4.D2 // t = q - m - VUMIN V4.S4, V3.S4, V3.S4 // m = min(t, m) - VSHL $0x20, V3.D2, V3.D2 - MUL R10, R12, R14 - MUL R11, R13, R15 - VMOV R14, V6.D[0] - VMOV R15, V6.D[1] - VSHL $0x1f, V6.D2, V8.D2 - VSHL $0x18, V6.D2, V9.D2 - VSUB V9.D2, V8.D2, V8.D2 - VSUB V6.D2, V8.D2, V7.D2 - VAND V7.B16, V1.B16, V7.B16 - VSHL $0x1f, V7.D2, V8.D2 - VSHL $0x18, V7.D2, V9.D2 - VSUB V9.D2, V8.D2, V8.D2 - VADD V7.D2, V8.D2, V7.D2 - VADD V7.D2, V6.D2, V7.D2 - VUSHR $0x20, V7.D2, V7.D2 - VSUB V0.D2, V7.D2, V8.D2 // t = q - m - VUMIN V8.S4, V7.S4, V7.S4 // m = min(t, m) - VADD V7.S4, V3.S4, V3.S4 - VST1.P [V3.S4], 16(R0) // res = b - SUB $1, R3, R3 - JMP loop7 - -done8: - RET diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index 968f53d4fc..cc71503d82 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 4254762842095097722 +// We include the hash to force the Go compiler to recompile: 5443918388234640397 #include "../asm/element_31b_arm64.s" diff --git a/field/babybear/vector_arm64.go b/field/babybear/vector_arm64.go index f1f7a48992..ec387191ed 100644 --- a/field/babybear/vector_arm64.go +++ b/field/babybear/vector_arm64.go @@ -16,9 +16,6 @@ func subVec(res, a, b *Element, n uint64) //go:noescape func sumVec(t *uint64, a *Element, n uint64) -//go:noescape -func mulVec(res, a, b *Element, n uint64) - // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
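+// Only full blocks of 4 elements go through the assembly kernel; the
+// tail, if any, falls back to the generic Go path, so callers need not
+// pass length-aligned vectors.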
func (vector *Vector) Add(a, b Vector) { @@ -100,19 +97,5 @@ func (vector *Vector) InnerProduct(other Vector) (res Element) { // Mul multiplies two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Mul(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Mul: vectors don't have the same length") - } - n := uint64(len(a)) - if n == 0 { - return - } - - const blockSize = 4 - mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize) - if n%blockSize != 0 { - // call mulVecGeneric on the rest - start := n - n%blockSize - mulVecGeneric((*vector)[start:], a[start:], b[start:]) - } + mulVecGeneric(*vector, a, b) } diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 2db1a857e6..0c7fdd77bb 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -234,7 +234,7 @@ func GenerateF31ASM(f *FFArm64, hasVector bool) error { f.generateAddVecF31() f.generateSubVecF31() f.generateSumVecF31() - f.generateMulVecF31() + // f.generateMulVecF31() // f.generateScalarMulVecF31() // f.generateInnerProdVecF31() diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index 2e58d17acc..8f1490d2d5 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -194,200 +194,3 @@ func (f *FFArm64) generateSumVecF31() { f.RET() } - -func (f *FFArm64) generateMulVecF31() { - f.Comment("mulVec(res, a, b *Element, n uint64)") - registers := f.FnHeader("mulVec", 0, 32) - defer f.AssertCleanStack(0, 0) - - // registers - resPtr := registers.Pop() - aPtr := registers.Pop() - bPtr := registers.Pop() - n := registers.Pop() - - // labels - loop := f.NewLabel("loop") - done := f.NewLabel("done") - - // load arguments - f.LDP("res+0(FP)", resPtr, aPtr) - f.LDP("b+16(FP)", bPtr, n) - - // a := registers.PopV() - // b := registers.PopV() - // t := registers.PopV() - // q := registers.PopV() - // qInvNeg := registers.PopV() - - // f.VMOVS("$const_q", q) - // f.VDUP(q.SAt(0), q.S4(), "broadcast q into "+string(q)) - - // f.VMOVS("$const_qInvNeg", qInvNeg) - // f.VDUP(qInvNeg.SAt(0), qInvNeg.S4(), "broadcast qInvNeg into "+string(qInvNeg)) - - q := registers.PopV() - - f.VMOVS("$const_q", q) - f.VDUP(q.DAt(0), q.D2(), "broadcast q into "+string(q)) - - const maxUint32 = 0xFFFFFFFF - mask := registers.PopV() - f.VMOVQ_cst(maxUint32, maxUint32, mask) - - f.LABEL(loop) - - f.CBZ(n, done) - - a0 := registers.Pop() - a1 := registers.Pop() - b0 := registers.Pop() - b1 := registers.Pop() - r0 := registers.Pop() - r1 := registers.Pop() - - v := registers.PopV() - m := registers.PopV() - t1 := registers.PopV() - t2 := registers.PopV() - - a0_2 := registers.Pop() - a1_2 := registers.Pop() - b0_2 := registers.Pop() - b1_2 := registers.Pop() - r0_2 := registers.Pop() - r1_2 := registers.Pop() - v_2 := registers.PopV() - m_2 := registers.PopV() - t1_2 := registers.PopV() - t2_2 := registers.PopV() - - // let's do 2 by 2 to start with; - f.MOVWUP_Load(4, aPtr, a0) - f.MOVWUP_Load(4, aPtr, a1) - f.MOVWUP_Load(4, bPtr, b0) - f.MOVWUP_Load(4, bPtr, b1) - - f.MUL(a0, b0, r0) - f.MUL(a1, b1, r1) - - f.VMOV(r0, v.DAt(0)) - f.VMOV(r1, v.DAt(1)) - - // qInvNeg == 2**31 - 2**24 -1 - // so we shift left by 31, store in a vector - // we shift left by 24, store in a vector - // we subtract the two vectors - f.VSHL(31, v.D2(), t1.D2()) - f.VSHL(24, v.D2(), t2.D2()) - f.MOVWUP_Load(4, aPtr, a0_2) - 
f.MOVWUP_Load(4, aPtr, a1_2) - - f.VSUB(t2.D2(), t1.D2(), t1.D2()) - f.VSUB(v.D2(), t1.D2(), m.D2()) - f.MOVWUP_Load(4, bPtr, b0_2) - f.MOVWUP_Load(4, bPtr, b1_2) - - // here we just want to keep m=low bits(vRes) - f.VAND(m.B16(), mask.B16(), m.B16()) - - // q == 2**31 - 2**24 + 1 - f.VSHL(31, m.D2(), t1.D2()) - f.VSHL(24, m.D2(), t2.D2()) - f.VSUB(t2.D2(), t1.D2(), t1.D2()) - f.VADD(m.D2(), t1.D2(), m.D2()) - - f.VADD(m.D2(), v.D2(), m.D2()) - f.VUSHR(32, m.D2(), m.D2()) - - // now we do mod q if needed - f.VSUB(q.D2(), m.D2(), t1.D2(), "t = q - m") - f.VUMIN(t1.S4(), m.S4(), m.S4(), "m = min(t, m)") - - f.VSHL(32, m.D2(), m.D2()) - - // f.VMOV(m.DAt(0), r0) - // f.VMOV(m.DAt(1), r1) - - // f.MOVWUP_Store(r0, resPtr, 4) - // f.MOVWUP_Store(r1, resPtr, 4) - - f.MUL(a0_2, b0_2, r0_2) - f.MUL(a1_2, b1_2, r1_2) - - f.VMOV(r0_2, v_2.DAt(0)) - f.VMOV(r1_2, v_2.DAt(1)) - - // qInvNeg == 2**31 - 2**24 -1 - // so we shift left by 31, store in a vector - // we shift left by 24, store in a vector - // we subtract the two vectors - f.VSHL(31, v_2.D2(), t1_2.D2()) - f.VSHL(24, v_2.D2(), t2_2.D2()) - f.VSUB(t2_2.D2(), t1_2.D2(), t1_2.D2()) - f.VSUB(v_2.D2(), t1_2.D2(), m_2.D2()) - - // here we just want to keep m=low bits(vRes) - f.VAND(m_2.B16(), mask.B16(), m_2.B16()) - - // q == 2**31 - 2**24 + 1 - f.VSHL(31, m_2.D2(), t1_2.D2()) - f.VSHL(24, m_2.D2(), t2_2.D2()) - f.VSUB(t2_2.D2(), t1_2.D2(), t1_2.D2()) - f.VADD(m_2.D2(), t1_2.D2(), m_2.D2()) - - f.VADD(m_2.D2(), v_2.D2(), m_2.D2()) - f.VUSHR(32, m_2.D2(), m_2.D2()) - - // now we do mod q if needed - f.VSUB(q.D2(), m_2.D2(), t1_2.D2(), "t = q - m") - f.VUMIN(t1_2.S4(), m_2.S4(), m_2.S4(), "m = min(t, m)") - - f.VADD(m_2.S4(), m.S4(), m.S4()) - // f.VREV64(m.B16(), m.B16()) - - f.VST1_P(m.S4(), resPtr, 4*4, "res = b") - - // f.VMOV(m_2.DAt(0), r0_2) - // f.VMOV(m_2.DAt(1), r1_2) - - // f.MOVWUP_Store(r0_2, resPtr, 4) - // f.MOVWUP_Store(r1_2, resPtr, 4) - - // func montReduce(v uint64) uint32 { - // m := uint32(v) * qInvNeg - // t := uint32((v + uint64(m)*q) >> 32) - // if t >= q { - // t -= q - // } - // return t - // } - - // g.VST1_P(vRes.D2(), resPtr, 0) - - // const offset = 4 * 4 // we process 4 uint32 at a time - - // f.VLD1_P(offset, aPtr, a.S4()) - // f.VLD1_P(offset, bPtr, b.S4()) - - // // let's compute p1 := a1 * b1 - // // f.VPMULL(a.S4(), b.S4(), p1.D2()) - // // let's move the low words in t - // // f.VMOV(p1.D2(), t.D2()) - - // f.VUSHLL2(0, a.S4(), a.D2(), "convert high words to 64 bits") - // f.VUSHLL2(0, b.S4(), b.D2(), "convert high words to 64 bits") - - // // f.VMUL(a.S4(), b.S4(), b.S4(), "b = a * b") - // f.VSUB(q.S4(), b.S4(), t.S4(), "t = q - b") - // f.VUMIN(t.S4(), b.S4(), b.S4(), "b = min(t, b)") - // f.VST1_P(b.S4(), resPtr, offset, "res = b") - - // decrement n - f.SUB(1, n, n) - f.JMP(loop) - - f.LABEL(done) - - f.RET() -} diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index e31e88ce58..c509764c56 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -157,9 +157,6 @@ func subVec(res, a, b *{{.ElementName}}, n uint64) //go:noescape func sumVec(t *uint64, a *{{.ElementName}}, n uint64) -//go:noescape -func mulVec(res, a, b *{{.ElementName}}, n uint64) - // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
func (vector *Vector) Add(a, b Vector) { @@ -241,21 +238,7 @@ func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { // Mul multiplies two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Mul(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Mul: vectors don't have the same length") - } - n := uint64(len(a)) - if n == 0 { - return - } - - const blockSize = 4 - mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize) - if n % blockSize != 0 { - // call mulVecGeneric on the rest - start := n - n % blockSize - mulVecGeneric((*vector)[start:], a[start:], b[start:]) - } + mulVecGeneric(*vector, a, b) } ` diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index 968f53d4fc..cc71503d82 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 4254762842095097722 +// We include the hash to force the Go compiler to recompile: 5443918388234640397 #include "../asm/element_31b_arm64.s" diff --git a/field/koalabear/vector_arm64.go b/field/koalabear/vector_arm64.go index ca3492e953..f322e7043b 100644 --- a/field/koalabear/vector_arm64.go +++ b/field/koalabear/vector_arm64.go @@ -16,9 +16,6 @@ func subVec(res, a, b *Element, n uint64) //go:noescape func sumVec(t *uint64, a *Element, n uint64) -//go:noescape -func mulVec(res, a, b *Element, n uint64) - // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { @@ -100,19 +97,5 @@ func (vector *Vector) InnerProduct(other Vector) (res Element) { // Mul multiplies two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
func (vector *Vector) Mul(a, b Vector) { - if len(a) != len(b) || len(a) != len(*vector) { - panic("vector.Mul: vectors don't have the same length") - } - n := uint64(len(a)) - if n == 0 { - return - } - - const blockSize = 4 - mulVec(&(*vector)[0], &a[0], &b[0], n/blockSize) - if n%blockSize != 0 { - // call mulVecGeneric on the rest - start := n - n%blockSize - mulVecGeneric((*vector)[start:], a[start:], b[start:]) - } + mulVecGeneric(*vector, a, b) } diff --git a/go.mod b/go.mod index da78c651a4..925b44c359 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.0.0 + github.com/consensys/bavard v0.1.25 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 @@ -15,8 +15,6 @@ require ( gopkg.in/yaml.v2 v2.4.0 ) -replace github.com/consensys/bavard => ../bavard - require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 489a6fcef7..798885eed6 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/consensys/bavard v0.1.25 h1:5YcSBnp03/HvfpKaIQLr/ecspTp2k8YNR5rQLOWvUyc= +github.com/consensys/bavard v0.1.25/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= From fa8cf7b117c4b3e172cea01d1e271e6ea7640445 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 19 Dec 2024 11:57:06 -0600 Subject: [PATCH 71/74] style: remove reg aliases in asm --- field/asm/element_31b_arm64.s | 58 +++++++------------- field/babybear/element_arm64.s | 2 +- field/generator/asm/arm64/build.go | 3 - field/generator/asm/arm64/element_vec_F31.go | 19 +++---- field/koalabear/element_arm64.s | 2 +- 5 files changed, 30 insertions(+), 54 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index 87ca1c41f1..d01a3dcf73 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -5,62 +5,44 @@ // addVec(res, a, b *Element, n uint64) TEXT ·addVec(SB), NOFRAME|NOSPLIT, $0-32 - LDP res+0(FP), (R0, R1) - LDP b+16(FP), (R2, R3) - -#define a V0 -#define b V1 -#define t V2 -#define q V3 - VMOVS $const_q, q - VDUP q.S[0], q.S4 // broadcast q into q + LDP res+0(FP), (R0, R1) + LDP b+16(FP), (R2, R3) + VMOVS $const_q, V3 + VDUP V3.S[0], V3.S4 // broadcast q into V3 loop1: CBZ R3, done2 - VLD1.P 16(R1), [a.S4] - VLD1.P 16(R2), [b.S4] - VADD a.S4, b.S4, b.S4 // b = a + b - VSUB q.S4, b.S4, t.S4 // t = b - q - VUMIN t.S4, b.S4, b.S4 // b = min(t, b) - VST1.P [b.S4], 16(R0) // res = b + VLD1.P 16(R1), [V0.S4] + VLD1.P 16(R2), [V1.S4] + VADD V0.S4, V1.S4, V1.S4 // b = a + b + VSUB V3.S4, V1.S4, V2.S4 // t = b - q + VUMIN V2.S4, V1.S4, V1.S4 // b = min(t, b) + VST1.P [V1.S4], 16(R0) // res = b SUB $1, R3, R3 
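+	// note: VSUB/VUMIN implement a branchless mod-q reduction: when
+	// a+b >= q the lane b-q is smaller and survives the VUMIN; when
+	// a+b < q the subtraction wraps around 2^32 to a large value, so
+	// the unreduced sum is kept. either way the result is (a+b) mod q.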
JMP loop1 done2: -#undef a -#undef b -#undef t -#undef q RET // subVec(res, a, b *Element, n uint64) TEXT ·subVec(SB), NOFRAME|NOSPLIT, $0-32 - LDP res+0(FP), (R0, R1) - LDP b+16(FP), (R2, R3) - -#define a V0 -#define b V1 -#define t V2 -#define q V3 - VMOVS $const_q, q - VDUP q.S[0], q.S4 // broadcast q into q + LDP res+0(FP), (R0, R1) + LDP b+16(FP), (R2, R3) + VMOVS $const_q, V3 + VDUP V3.S[0], V3.S4 // broadcast q into V3 loop3: CBZ R3, done4 - VLD1.P 16(R1), [a.S4] - VLD1.P 16(R2), [b.S4] - VSUB b.S4, a.S4, b.S4 // b = a - b - VADD b.S4, q.S4, t.S4 // t = b + q - VUMIN t.S4, b.S4, b.S4 // b = min(t, b) - VST1.P [b.S4], 16(R0) // res = b + VLD1.P 16(R1), [V0.S4] + VLD1.P 16(R2), [V1.S4] + VSUB V1.S4, V0.S4, V1.S4 // b = a - b + VADD V1.S4, V3.S4, V2.S4 // t = b + q + VUMIN V2.S4, V1.S4, V1.S4 // b = min(t, b) + VST1.P [V1.S4], 16(R0) // res = b SUB $1, R3, R3 JMP loop3 done4: -#undef a -#undef b -#undef q -#undef t RET // sumVec(t *uint64, a *[]uint32, n uint64) res = sum(a[0...n]) diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index cc71503d82..4117788ab0 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 5443918388234640397 +// We include the hash to force the Go compiler to recompile: 18135195997225779195 #include "../asm/element_31b_arm64.s" diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 0c7fdd77bb..9137d0372f 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -234,9 +234,6 @@ func GenerateF31ASM(f *FFArm64, hasVector bool) error { f.generateAddVecF31() f.generateSubVecF31() f.generateSumVecF31() - // f.generateMulVecF31() - // f.generateScalarMulVecF31() - // f.generateInnerProdVecF31() return nil } diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index 8f1490d2d5..f0c5df646a 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -6,11 +6,9 @@ func (f *FFArm64) generateAddVecF31() { f.Comment("addVec(res, a, b *Element, n uint64)") registers := f.FnHeader("addVec", 0, 32) defer f.AssertCleanStack(0, 0) - defer registers.AssertCleanState() // registers resPtr := registers.Pop() - // qqPtr := registers.Pop() aPtr := registers.Pop() bPtr := registers.Pop() n := registers.Pop() @@ -23,10 +21,10 @@ func (f *FFArm64) generateAddVecF31() { f.LDP("res+0(FP)", resPtr, aPtr) f.LDP("b+16(FP)", bPtr, n) - a := registers.PopV("a") - b := registers.PopV("b") - t := registers.PopV("t") - q := registers.PopV("q") + a := registers.PopV() + b := registers.PopV() + t := registers.PopV() + q := registers.PopV() f.VMOVS("$const_q", q) f.VDUP(q.SAt(0), q.S4(), "broadcast q into "+string(q)) @@ -62,7 +60,6 @@ func (f *FFArm64) generateSubVecF31() { f.Comment("subVec(res, a, b *Element, n uint64)") registers := f.FnHeader("subVec", 0, 32) defer f.AssertCleanStack(0, 0) - defer registers.AssertCleanState() // registers resPtr := registers.Pop() @@ -78,10 +75,10 @@ func (f *FFArm64) generateSubVecF31() { f.LDP("res+0(FP)", resPtr, aPtr) f.LDP("b+16(FP)", bPtr, n) - a := registers.PopV("a") - b := registers.PopV("b") - t := registers.PopV("t") - q := registers.PopV("q") + a := registers.PopV() + b := registers.PopV() + t := registers.PopV() + q := registers.PopV() f.VMOVS("$const_q", q) f.VDUP(q.SAt(0), q.S4(), "broadcast 
q into "+string(q)) diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index cc71503d82..4117788ab0 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 5443918388234640397 +// We include the hash to force the Go compiler to recompile: 18135195997225779195 #include "../asm/element_31b_arm64.s" From ca6efa40b7087f26ee1169bea3399b1c32252a57 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 19 Dec 2024 12:01:49 -0600 Subject: [PATCH 72/74] docs: add more context to asm arm doc --- field/asm/element_31b_arm64.s | 3 +++ field/babybear/element_arm64.s | 2 +- field/babybear/vector_arm64.go | 15 +++++++++------ field/generator/asm/arm64/element_vec_F31.go | 3 +++ .../internal/templates/element/vector_ops_asm.go | 15 +++++++++------ field/koalabear/element_arm64.s | 2 +- field/koalabear/vector_arm64.go | 15 +++++++++------ 7 files changed, 35 insertions(+), 20 deletions(-) diff --git a/field/asm/element_31b_arm64.s b/field/asm/element_31b_arm64.s index d01a3dcf73..6581ba0409 100644 --- a/field/asm/element_31b_arm64.s +++ b/field/asm/element_31b_arm64.s @@ -4,6 +4,7 @@ #include "go_asm.h" // addVec(res, a, b *Element, n uint64) +// n is the number of blocks of 4 uint32 to process TEXT ·addVec(SB), NOFRAME|NOSPLIT, $0-32 LDP res+0(FP), (R0, R1) LDP b+16(FP), (R2, R3) @@ -25,6 +26,7 @@ done2: RET // subVec(res, a, b *Element, n uint64) +// n is the number of blocks of 4 uint32 to process TEXT ·subVec(SB), NOFRAME|NOSPLIT, $0-32 LDP res+0(FP), (R0, R1) LDP b+16(FP), (R2, R3) @@ -46,6 +48,7 @@ done4: RET // sumVec(t *uint64, a *[]uint32, n uint64) res = sum(a[0...n]) +// n is the number of blocks of 16 uint32 to process TEXT ·sumVec(SB), NOFRAME|NOSPLIT, $0-24 // zeroing accumulators VMOVQ $0, $0, V4 diff --git a/field/babybear/element_arm64.s b/field/babybear/element_arm64.s index 4117788ab0..f3ddfa315e 100644 --- a/field/babybear/element_arm64.s +++ b/field/babybear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 18135195997225779195 +// We include the hash to force the Go compiler to recompile: 8620676634583589757 #include "../asm/element_31b_arm64.s" diff --git a/field/babybear/vector_arm64.go b/field/babybear/vector_arm64.go index ec387191ed..3ab1b562ff 100644 --- a/field/babybear/vector_arm64.go +++ b/field/babybear/vector_arm64.go @@ -56,12 +56,6 @@ func (vector *Vector) Sub(a, b Vector) { } } -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res Element) { n := uint64(len(*vector)) @@ -87,6 +81,15 @@ func (vector *Vector) Sum() (res Element) { return } +// note: unfortunately, as of Dec. 2024, Golang doesn't support enough NEON instructions +// for these to be worth it in assembly. Will hopefully revisit in future versions. + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // InnerProduct computes the inner product of two vectors. 
// It panics if the vectors don't have the same length. func (vector *Vector) InnerProduct(other Vector) (res Element) { diff --git a/field/generator/asm/arm64/element_vec_F31.go b/field/generator/asm/arm64/element_vec_F31.go index f0c5df646a..6e6a818c53 100644 --- a/field/generator/asm/arm64/element_vec_F31.go +++ b/field/generator/asm/arm64/element_vec_F31.go @@ -4,6 +4,7 @@ import "github.com/consensys/bavard/arm64" func (f *FFArm64) generateAddVecF31() { f.Comment("addVec(res, a, b *Element, n uint64)") + f.Comment("n is the number of blocks of 4 uint32 to process") registers := f.FnHeader("addVec", 0, 32) defer f.AssertCleanStack(0, 0) @@ -58,6 +59,7 @@ func (f *FFArm64) generateAddVecF31() { func (f *FFArm64) generateSubVecF31() { f.Comment("subVec(res, a, b *Element, n uint64)") + f.Comment("n is the number of blocks of 4 uint32 to process") registers := f.FnHeader("subVec", 0, 32) defer f.AssertCleanStack(0, 0) @@ -112,6 +114,7 @@ func (f *FFArm64) generateSubVecF31() { func (f *FFArm64) generateSumVecF31() { f.Comment("sumVec(t *uint64, a *[]uint32, n uint64) res = sum(a[0...n])") + f.Comment("n is the number of blocks of 16 uint32 to process") registers := f.FnHeader("sumVec", 0, 3*8) defer f.AssertCleanStack(0, 0) diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index c509764c56..f012489e51 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -197,12 +197,6 @@ func (vector *Vector) Sub(a, b Vector) { } } -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { - scalarMulVecGeneric(*vector, a, b) -} - // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res {{.ElementName}}) { n := uint64(len(*vector)) @@ -228,6 +222,15 @@ func (vector *Vector) Sum() (res {{.ElementName}}) { return } +// note: unfortunately, as of Dec. 2024, Golang doesn't support enough NEON instructions +// for these to be worth it in assembly. Will hopefully revisit in future versions. + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { + scalarMulVecGeneric(*vector, a, b) +} + // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. func (vector *Vector) InnerProduct(other Vector) (res {{.ElementName}}) { diff --git a/field/koalabear/element_arm64.s b/field/koalabear/element_arm64.s index 4117788ab0..f3ddfa315e 100644 --- a/field/koalabear/element_arm64.s +++ b/field/koalabear/element_arm64.s @@ -5,6 +5,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 18135195997225779195 +// We include the hash to force the Go compiler to recompile: 8620676634583589757 #include "../asm/element_31b_arm64.s" diff --git a/field/koalabear/vector_arm64.go b/field/koalabear/vector_arm64.go index f322e7043b..dd410b2058 100644 --- a/field/koalabear/vector_arm64.go +++ b/field/koalabear/vector_arm64.go @@ -56,12 +56,6 @@ func (vector *Vector) Sub(a, b Vector) { } } -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. 
-// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res Element) { n := uint64(len(*vector)) @@ -87,6 +81,15 @@ func (vector *Vector) Sum() (res Element) { return } +// note: unfortunately, as of Dec. 2024, Golang doesn't support enough NEON instructions +// for these to be worth it in assembly. Will hopefully revisit in future versions. + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + // InnerProduct computes the inner product of two vectors. // It panics if the vectors don't have the same length. func (vector *Vector) InnerProduct(other Vector) (res Element) { From be88799debaf56957ef4591b6189921e927a88dd Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 19 Dec 2024 12:04:20 -0600 Subject: [PATCH 73/74] style: minor code cleaning --- field/generator/asm/amd64/build.go | 2 -- field/generator/asm/arm64/build.go | 1 - 2 files changed, 3 deletions(-) diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 3cce93670d..a5407b9d14 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -221,10 +221,8 @@ func GenerateCommonASM(w io.Writer, nbWords, nbBits int, hasVector bool) error { if nbBits == 31 { return GenerateF31ASM(f, hasVector) } else { - fmt.Printf("nbWords: %d, nbBits: %d\n", nbWords, nbBits) panic("not implemented") } - } f.GenerateReduceDefine() diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index 9137d0372f..5b8d1893a0 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -61,7 +61,6 @@ func GenerateCommonASM(w io.Writer, nbWords, nbBits int, hasVector bool) error { } else { panic("not implemented") } - } if f.NbWords%2 != 0 { From 7048ffe252383c704914927ca1bf804568aafce1 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 19 Dec 2024 18:21:33 +0000 Subject: [PATCH 74/74] feat: fix template for amd --- field/babybear/vector_amd64.go | 1 + field/generator/internal/templates/element/vector_ops_asm.go | 1 + field/koalabear/vector_amd64.go | 1 + 3 files changed, 3 insertions(+) diff --git a/field/babybear/vector_amd64.go b/field/babybear/vector_amd64.go index 4a7fca4b48..742b60ce2e 100644 --- a/field/babybear/vector_amd64.go +++ b/field/babybear/vector_amd64.go @@ -116,6 +116,7 @@ func (vector *Vector) Sum() (res Element) { var t [8]uint64 // stores the accumulators (not reduced mod q) sumVec(&t[0], &(*vector)[0], n/blockSize) // we reduce the accumulators mod q and add to res + var v Element for i := 0; i < 8; i++ { v[0] = uint32(t[i] % q) res.Add(&res, &v) diff --git a/field/generator/internal/templates/element/vector_ops_asm.go b/field/generator/internal/templates/element/vector_ops_asm.go index f012489e51..2a1d0bf8e6 100644 --- a/field/generator/internal/templates/element/vector_ops_asm.go +++ b/field/generator/internal/templates/element/vector_ops_asm.go @@ -356,6 +356,7 @@ func (vector *Vector) Sum() (res {{.ElementName}}) { var t [8]uint64 // stores the accumulators (not reduced mod q) sumVec(&t[0], &(*vector)[0], n/blockSize) // we reduce the accumulators mod q and add to res + var v {{.ElementName}} for i := 0; i < 8; i++ { v[0] = uint32(t[i] 
% q) res.Add(&res, &v) diff --git a/field/koalabear/vector_amd64.go b/field/koalabear/vector_amd64.go index c3372f1650..ede9d1afa9 100644 --- a/field/koalabear/vector_amd64.go +++ b/field/koalabear/vector_amd64.go @@ -116,6 +116,7 @@ func (vector *Vector) Sum() (res Element) { var t [8]uint64 // stores the accumulators (not reduced mod q) sumVec(&t[0], &(*vector)[0], n/blockSize) // we reduce the accumulators mod q and add to res + var v Element for i := 0; i < 8; i++ { v[0] = uint32(t[i] % q) res.Add(&res, &v)
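
A note on the addVec/subVec kernels annotated in patch 72: each NEON iteration consumes one block of 4 uint32 lanes, and because both inputs are already reduced below q < 2^31, the lane-wise sum cannot wrap a uint32, so a single conditional subtract finishes the reduction. The scalar Go sketch below models that block semantics under stated assumptions: the modulus is KoalaBear's 2^31 - 2^24 + 1 (picked purely for illustration; BabyBear works identically with its own q), and the branchy reduction stands in for whatever branch-free NEON sequence the generated kernel actually emits.

// Scalar model of the addVec/subVec block semantics. This is a sketch,
// not the generated code: the real kernels work on 128-bit NEON
// registers and avoid branches.
package main

import "fmt"

const q uint32 = 2130706433 // 2^31 - 2^24 + 1 (KoalaBear, assumed for illustration)

// addVecGeneric: res[i] = a[i] + b[i] mod q.
// a[i], b[i] < q < 2^31, so the uint32 sum cannot overflow.
func addVecGeneric(res, a, b []uint32) {
	for i := range res {
		t := a[i] + b[i]
		if t >= q {
			t -= q
		}
		res[i] = t
	}
}

// subVecGeneric: res[i] = a[i] - b[i] mod q.
func subVecGeneric(res, a, b []uint32) {
	for i := range res {
		t := a[i] - b[i]
		if a[i] < b[i] { // borrow: wrap back into [0, q)
			t += q
		}
		res[i] = t
	}
}

func main() {
	a := []uint32{1, q - 1, 123, 7}
	b := []uint32{2, 1, 100, 9}
	res := make([]uint32, 4) // one "block" of 4 uint32, as in the NEON loop
	addVecGeneric(res, a, b)
	fmt.Println(res) // [3 0 223 16]
	subVecGeneric(res, a, b)
	fmt.Println(res) // [q-1 q-2 23 q-2]
}

This is also why the new comments say "n is the number of blocks of 4 uint32 to process": the Go callers pass the element count divided by the block size, as the sumVec caller's n/blockSize suggests.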
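
The sumVec contract ("n is the number of blocks of 16 uint32 to process") pairs with the `var v Element` fix in patch 74: the assembly hands back eight unreduced uint64 accumulators through `t *uint64`, and the Go side folds them mod q one at a time. The sketch below models why skipping intermediate reductions is safe — each lane grows by less than 2^31 per element, so a uint64 lane could only overflow after more than 2^33 accumulated elements, far beyond any realistic vector length. The lane-assignment scheme (i % 8) is an assumption for illustration; the NEON code interleaves differently, but addition is commutative so the result is the same.

// Scalar model of sumVec: accumulate without reduction, fold mod q at
// the end — the Go side's "v[0] = uint32(t[i] % q); res.Add(&res, &v)".
package main

import "fmt"

const q = 2013265921 // 2^31 - 2^27 + 1 (BabyBear, assumed for illustration)

func sumVecGeneric(a []uint32) uint64 {
	var t [8]uint64 // mirrors the 8 accumulators returned via t *uint64
	for i, x := range a {
		t[i%8] += uint64(x) // widening add, never reduced in the hot loop
	}
	res := uint64(0)
	for i := 0; i < 8; i++ {
		res = (res + t[i]%q) % q
	}
	return res
}

func main() {
	a := []uint32{q - 1, 1, q - 1, 1} // sums to 2q ≡ 0 mod q
	fmt.Println(sumVecGeneric(a))     // 0
}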
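
On the repeated "We include the hash to force the Go compiler to recompile" churn in babybear/element_arm64.s and koalabear/element_arm64.s: those files are thin wrappers that #include the shared field/asm/element_31b_arm64.s, and the Go build cache fingerprints only the files it sees in the package, not transitively included assembly. Embedding a hash of the shared file in the wrapper makes the wrapper's bytes — and therefore the cache key — change whenever the shared assembly changes. A hypothetical generator-side sketch follows; the actual hash function used by gnark-crypto's generator is not shown in this patch, so FNV-1a here is purely an assumption.

// Hypothetical sketch of deriving the recompile hash embedded in each
// per-field wrapper .s file (cache-busting idea only; the real
// generator may hash differently).
package main

import (
	"fmt"
	"hash/fnv"
	"os"
)

func recompileHash(sharedAsmPath string) (uint64, error) {
	data, err := os.ReadFile(sharedAsmPath)
	if err != nil {
		return 0, err
	}
	h := fnv.New64a()
	h.Write(data) // hash the included file's full content
	return h.Sum64(), nil
}

func main() {
	sum, err := recompileHash("field/asm/element_31b_arm64.s")
	if err != nil {
		panic(err)
	}
	// The wrapper comment changes whenever the included file changes,
	// which invalidates Go's build cache for the importing package.
	fmt.Printf("// We include the hash to force the Go compiler to recompile: %d\n", sum)
	fmt.Println(`#include "../asm/element_31b_arm64.s"`)
}

Any content-sensitive hash works here; the only requirement is that it changes when the included file changes.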
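
Finally, the ScalarMul methods that patch 72 moves below the NEON note fall back to scalarMulVecGeneric on arm64. That helper is not shown in this diff; the sketch below is a plausible shape only, assuming Vector is []Element and that a 31-bit Element is a single uint32 limb — and it uses plain modular multiplication where the generated Element.Mul works in Montgomery form.

// Sketch of a generic scalar-multiplication fallback (hypothetical —
// the real scalarMulVecGeneric lives elsewhere in the generated code).
package main

import "fmt"

const qBB uint32 = 2013265921 // BabyBear, assumed for illustration

type Element [1]uint32 // one uint32 limb for a 31-bit field
type Vector []Element

// Mul is plain modular multiplication; the generated Element.Mul uses
// Montgomery form, which this sketch deliberately ignores.
func (z *Element) Mul(x, y *Element) *Element {
	z[0] = uint32(uint64(x[0]) * uint64(y[0]) % uint64(qBB))
	return z
}

func scalarMulVecGeneric(res, a Vector, b *Element) {
	if len(a) != len(res) {
		panic("vector.ScalarMul: vectors don't have the same length")
	}
	for i := 0; i < len(a); i++ {
		res[i].Mul(&a[i], b) // one modular multiplication per element
	}
}

func main() {
	a := Vector{{3}, {5}}
	res := make(Vector, 2)
	scalarMulVecGeneric(res, a, &Element{7})
	fmt.Println(res) // [[21] [35]]
}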