diff --git a/src/lib/provable/gadgets/native-curve.ts b/src/lib/provable/gadgets/native-curve.ts index 4998993805..625fade1d4 100644 --- a/src/lib/provable/gadgets/native-curve.ts +++ b/src/lib/provable/gadgets/native-curve.ts @@ -4,11 +4,14 @@ import { Fp, Fq } from '../../../bindings/crypto/finite-field.js'; import { PallasAffine } from '../../../bindings/crypto/elliptic-curve.js'; import { fieldToField3 } from './comparison.js'; import { Field3, ForeignField } from './foreign-field.js'; -import { exists, existsOne } from '../core/exists.js'; -import { bit, isConstant, packBits } from './common.js'; -import { TupleN } from '../../util/types.js'; +import { exists } from '../core/exists.js'; +import { bit, isConstant } from './common.js'; import { l } from './range-check.js'; -import { createField, getField } from '../core/field-constructor.js'; +import { + createBool, + createField, + getField, +} from '../core/field-constructor.js'; import { Snarky } from '../../../snarky.js'; import { Provable } from '../provable.js'; import { MlPair } from '../../ml/base.js'; @@ -18,12 +21,13 @@ export { scale, fieldToShiftedScalar, field3ToShiftedScalar, - scaleShiftedSplit5, + scaleShifted, add, + ShiftedScalar, }; type Point = { x: Field; y: Field }; -type ShiftedScalar = { low5: TupleN; high250: Field }; +type ShiftedScalar = { lowBit: Bool; high254: Field }; /** * Gadget to scale a point by a scalar, where the scalar is represented as a _native_ Field. @@ -42,7 +46,7 @@ function scale(P: Point, s: Field): Point { let t = fieldToShiftedScalar(s); // return (t + 2^254)*P = (s - 2^254 + 2^254)*P = s*P - return scaleShiftedSplit5(P, t); + return scaleShifted(P, t); } /** @@ -52,66 +56,75 @@ function scale(P: Point, s: Field): Point { * This is the representation we use for scalars, since it can be used as input to `scaleShiftedSplit5()`. */ function fieldToShiftedScalar(s: Field): ShiftedScalar { - return field3ToShiftedScalar(fieldToField3(s)); + let sBig = fieldToField3(s); + + // assert that sBig is canonical mod p, so that we can't add (kp mod q) factors by doing things modulo q + ForeignField.assertLessThan(sBig, Fp.modulus); + + return field3ToShiftedScalar(sBig); } /** - * Converts a 3-limb bigint to a shifted representation t = s = 2^254 mod q, - * where t is represented as a 5-bit low part and a 250-bit high part. + * Converts a 3-limb bigint to a shifted representation t = s - 2^255 mod q, + * where t is represented as a low bit and a 254-bit high part. */ function field3ToShiftedScalar(s: Field3): ShiftedScalar { // constant case if (Field3.isConstant(s)) { - let t = Fq.mod(Field3.toBigint(s) - (1n << 254n)); - let low5 = createField(t & 0x1fn).toBits(5); - let high250 = createField(t >> 5n); - return { low5: TupleN.fromArray(5, low5), high250 }; + let t = Fq.mod(Field3.toBigint(s) - (1n << 255n)); + let lowBit = createBool((t & 1n) === 1n); + let high254 = createField(t >> 1n); + return { lowBit, high254 }; } - // compute t = s - 2^254 mod q using foreign field subtraction - let twoTo254 = Field3.from(1n << 254n); - let [t0, t1, t2] = ForeignField.sub(s, twoTo254, Fq.modulus); + // compute t = s - 2^255 mod q using foreign field subtraction + let twoTo255 = Field3.from(Fq.mod(1n << 255n)); + let t = ForeignField.sub(s, twoTo255, Fq.modulus); - // split t into 250 high bits and 5 low bits - // => split t0 into [5, 83] - let tLo = exists(5, () => { - let t = t0.toBigInt(); - return [bit(t, 0), bit(t, 1), bit(t, 2), bit(t, 3), bit(t, 4)]; + // it's necessary to prove that t is canonical -- otherwise its bit representation is ambiguous + ForeignField.assertLessThan(t, Fq.modulus); + + let [t0, t1, t2] = t; + + // split t into 254 high bits and a low bit + // => split t0 into [1, 87] + let [tLo, tHi0] = exists(2, () => { + let t0_ = t0.toBigInt(); + return [bit(t0_, 0), t0_ >> 1n]; }); - let tLoBools = TupleN.map(tLo, (x) => x.assertBool()); - let tHi0 = existsOne(() => t0.toBigInt() >> 5n); + let tLoBool = tLo.assertBool(); // prove split - // since we know that t0 < 2^88, this proves that t0High < 2^83 - packBits(tLo) - .add(tHi0.mul(1n << 5n)) - .assertEquals(t0); + // since we know that t0 < 2^88, this proves that t0High < 2^87 + tLo.add(tHi0.mul(2n)).assertEquals(t0); // pack tHi - // proves that tHi is in [0, 2^250) let tHi = tHi0 - .add(t1.mul(1n << (l - 5n))) - .add(t2.mul(1n << (2n * l - 5n))) + .add(t1.mul(1n << (l - 1n))) + .add(t2.mul(1n << (2n * l - 1n))) .seal(); - return { low5: tLoBools, high250: tHi }; + return { lowBit: tLoBool, high254: tHi }; } /** - * Internal helper to compute `(t + 2^254)*P`. - * `t` is expected to be split into 250 high bits (t >> 5) and 5 low bits (t & 0x1f). + * Internal helper to compute `(t + 2^255)*P`. + * `t` is expected to be split into 254 high bits (t >> 1) and a low bit (t & 1). + * + * The gadget proves that `tHi` is in [0, 2^254) but assumes that `tLo` consists of bits. * - * The gadget proves that `tHi` is in [0, 2^250) but assumes that `tLo` consists of bits. + * Optionally, you can specify a different number of high bits by passing in `numHighBits`. */ -function scaleShiftedSplit5( +function scaleShifted( { x, y }: Point, - { low5: tLo, high250: tHi }: ShiftedScalar + { lowBit: tLo, high254: tHi }: ShiftedScalar, + numHighBits = 254 ): Point { // constant case - if (isConstant(x, y, tHi, ...tLo)) { + if (isConstant(x, y, tHi, tLo)) { let sP = PallasAffine.scale( PallasAffine.fromNonzero({ x: x.toBigInt(), y: y.toBigInt() }), - Fq.add(packBits(tLo).toBigInt() + (tHi.toBigInt() << 5n), 1n << 254n) + Fq.mod(tLo.toField().toBigInt() + 2n * tHi.toBigInt() + (1n << 255n)) ); return { x: createField(sP.x), y: createField(sP.y) }; } @@ -119,34 +132,26 @@ function scaleShiftedSplit5( const Point = provable({ x: Field, y: Field }); const zero = createField(0n); - // R = (2*(t >> 5) + 1 + 2^250)P - // also proves that tHi is in [0, 2^250) - let [, RMl] = Snarky.group.scaleFastUnpack( + // R = (2*(t >> 1) + 1 + 2^255)P + // also returns a 255-bit representation of tHi + let [, RMl, [, ...tHiBitsMl]] = Snarky.group.scaleFastUnpack( [0, x.value, y.value], [0, tHi.value], - 250 + 255 ); let P = { x, y }; let R = { x: createField(RMl[1]), y: createField(RMl[2]) }; - let [t0, t1, t2, t3, t4] = tLo; - // R = t4 ? R : R - P = ((t >> 4) + 2^250)P - R = Provable.if(t4, Point, R, addNonZero(R, negate(P))); - - // R = ((t >> 3) + 2^251)P - // R = ((t >> 2) + 2^252)P - // R = ((t >> 1) + 2^253)P - for (let t of [t3, t2, t1]) { - R = addNonZero(R, R); - R = Provable.if(t, Point, addNonZero(R, P), R); + // prove that tHi has only `numHighBits` bits set + for (let i = numHighBits; i < 255; i++) { + createField(tHiBitsMl[i]).assertEquals(zero); } - // R = (t + 2^254)P - // in the final step, we allow a zero output to make it work for the 0 scalar - R = addNonZero(R, R); - let { result, isInfinity } = add(R, P); - result = Provable.if(isInfinity, Point, { x: zero, y: zero }, result); - R = Provable.if(t0, Point, result, R); + // R = tLo ? R : R - P = (t + 2^255)P + // we also handle a zero R-P result to make scaling work for the 0 scalar + let { result, isInfinity } = add(R, negate(P)); + let RmP = Provable.if(isInfinity, Point, { x: zero, y: zero }, result); + R = Provable.if(tLo, Point, R, RmP); return R; } @@ -194,15 +199,6 @@ function add(g: Point, h: Point): { result: Point; isInfinity: Bool } { return { result: { x: x3, y: y3 }, isInfinity }; } -/** - * Addition that asserts the result is non-zero. - */ -function addNonZero(g: Point, h: Point) { - let { result, isInfinity } = add(g, h); - isInfinity.assertFalse(); - return result; -} - /** * Negates a point. */ diff --git a/src/lib/provable/group.ts b/src/lib/provable/group.ts index 503fe522ea..da79585271 100644 --- a/src/lib/provable/group.ts +++ b/src/lib/provable/group.ts @@ -10,7 +10,7 @@ import { import { Provable } from './provable.js'; import { Bool } from './bool.js'; import { assert } from '../util/assert.js'; -import { add, scaleShiftedSplit5 } from './gadgets/native-curve.js'; +import { add, scaleShifted } from './gadgets/native-curve.js'; export { Group }; @@ -180,7 +180,7 @@ class Group { let g_proj = Pallas.scale(toProjective(this), scalar.toBigInt()); return fromProjective(g_proj); } else { - let result = scaleShiftedSplit5(this, scalar); + let result = scaleShifted(this, scalar); return new Group(result); } } diff --git a/src/lib/provable/scalar.ts b/src/lib/provable/scalar.ts index ce39c53aa8..4304dc60d3 100644 --- a/src/lib/provable/scalar.ts +++ b/src/lib/provable/scalar.ts @@ -3,12 +3,12 @@ import { Scalar as SignableFq } from '../../mina-signer/src/curve-bigint.js'; import { Field, checkBitLength } from './field.js'; import { FieldVar } from './core/fieldvar.js'; import { Bool } from './bool.js'; -import { TupleN } from '../util/types.js'; import { + ShiftedScalar, field3ToShiftedScalar, fieldToShiftedScalar, } from './gadgets/native-curve.js'; -import { isConstant, packBits } from './gadgets/common.js'; +import { isConstant } from './gadgets/common.js'; import { Provable } from './provable.js'; import { assert } from '../util/assert.js'; import type { HashInput } from './types/provable-derivers.js'; @@ -21,20 +21,20 @@ type ScalarConst = [0, bigint]; /** * Represents a {@link Scalar}. */ -class Scalar { +class Scalar implements ShiftedScalar { /** - * We represent a scalar s in shifted form `t = s - 2^254 mod q, - * split into its low 5 bits (t & 0x1f) and high 250 bits (t >> 5). - * The reason is that we can efficiently compute the scalar multiplication `(t + 2^254) * P = s * P`. + * We represent a scalar s in shifted form `t = s - 2^255 mod q, + * split into its low bit (t & 1) and high 254 bits (t >> 1). + * The reason is that we can efficiently compute the scalar multiplication `(t + 2^255) * P = s * P`. */ - low5: TupleN; - high250: Field; + lowBit: Bool; + high254: Field; static ORDER = Fq.modulus; - private constructor(low5: TupleN, high250: Field) { - this.low5 = low5; - this.high250 = high250; + private constructor(lowBit: Bool, high254: Field) { + this.lowBit = lowBit; + this.high254 = high254; } /** @@ -44,10 +44,10 @@ class Scalar { */ static from(s: Scalar | bigint | number | string): Scalar { if (s instanceof Scalar) return s; - let t = Fq.mod(BigInt(s) - (1n << 254n)); - let low5 = new Field(t & 0x1fn).toBits(5); - let high250 = new Field(t >> 5n); - return new Scalar(TupleN.fromArray(5, low5), high250); + let t = Fq.mod(BigInt(s) - (1n << 255n)); + let lowBit = new Bool((t & 1n) === 1n); + let high254 = new Field(t >> 1n); + return new Scalar(lowBit, high254); } /** @@ -56,8 +56,8 @@ class Scalar { * This is always possible and unambiguous, since the scalar field is larger than the base field. */ static fromNativeField(s: Field): Scalar { - let { low5, high250 } = fieldToShiftedScalar(s); - return new Scalar(low5, high250); + let { lowBit, high254 } = fieldToShiftedScalar(s); + return new Scalar(lowBit, high254); } /** @@ -65,8 +65,8 @@ class Scalar { * If a {@link Scalar} is constructed outside provable code, it is a constant. */ isConstant() { - let { low5, high250 } = this; - return isConstant(high250, ...low5); + let { lowBit, high254 } = this; + return isConstant(lowBit, high254); } /** @@ -85,11 +85,9 @@ class Scalar { * Convert this {@link Scalar} into a bigint */ toBigInt() { - let { low5, high250 } = this.toConstant(); - return Fq.add( - packBits(low5).toBigInt() + (high250.toBigInt() << 5n), - 1n << 254n - ); + let { lowBit, high254 } = this.toConstant(); + let t = lowBit.toField().toBigInt() + 2n * high254.toBigInt(); + return Fq.mod(t + (1n << 255n)); } /** @@ -104,8 +102,8 @@ class Scalar { let sBig = field3FromBits(bits); // convert to shifted representation - let { low5, high250 } = field3ToShiftedScalar(sBig); - return new Scalar(low5, high250); + let { lowBit, high254 } = field3ToShiftedScalar(sBig); + return new Scalar(lowBit, high254); } /** @@ -211,7 +209,7 @@ class Scalar { * The fields are not constrained to be boolean. */ static toFields(x: Scalar) { - return [...x.low5.map((b) => b.toField()), x.high250]; + return [x.lowBit.toField(), x.high254]; } /** @@ -239,7 +237,7 @@ class Scalar { * */ static toInput(x: Scalar): HashInput { - return { fields: [x.high250], packed: x.low5.map((f) => [f.toField(), 1]) }; + return { fields: [x.high254], packed: [[x.lowBit.toField(), 1]] }; } /** @@ -258,12 +256,12 @@ class Scalar { */ static fromFields(fields: Field[]): Scalar { assert( - fields.length === 6, - `Scalar.fromFields(): expected 6 fields, got ${fields.length}` + fields.length === 2, + `Scalar.fromFields(): expected 2 fields, got ${fields.length}` ); - let low5 = fields.slice(0, 5).map(Bool.Unsafe.fromField); - let high250 = fields[5]; - return new Scalar(TupleN.fromArray(5, low5), high250); + let lowBit = Bool.Unsafe.fromField(fields[0]); + let high254 = fields[1]; + return new Scalar(lowBit, high254); } /** @@ -272,7 +270,7 @@ class Scalar { * Returns the size of this type in {@link Field} elements. */ static sizeInFields(): number { - return 6; + return 2; } /** @@ -280,10 +278,10 @@ class Scalar { */ static check(s: Scalar) { /** - * It is not necessary to constrain the range of high250, because the only provable operation on Scalar + * It is not necessary to constrain the range of high254, because the only provable operation on Scalar * which relies on that range is scalar multiplication -- which constrains the range itself. */ - return s.low5.forEach(Bool.check); + return Bool.check(s.lowBit); } // ProvableExtended