Skip to content

Commit

Permalink
switch to repr with 1 low bit
Browse files Browse the repository at this point in the history
  • Loading branch information
mitschabaude committed Apr 4, 2024
1 parent 0f93cc2 commit 04f8648
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 106 deletions.
132 changes: 64 additions & 68 deletions src/lib/provable/gadgets/native-curve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ import { Fp, Fq } from '../../../bindings/crypto/finite-field.js';
import { PallasAffine } from '../../../bindings/crypto/elliptic-curve.js';
import { fieldToField3 } from './comparison.js';
import { Field3, ForeignField } from './foreign-field.js';
import { exists, existsOne } from '../core/exists.js';
import { bit, isConstant, packBits } from './common.js';
import { TupleN } from '../../util/types.js';
import { exists } from '../core/exists.js';
import { bit, isConstant } from './common.js';
import { l } from './range-check.js';
import { createField, getField } from '../core/field-constructor.js';
import {
createBool,
createField,
getField,
} from '../core/field-constructor.js';
import { Snarky } from '../../../snarky.js';
import { Provable } from '../provable.js';
import { MlPair } from '../../ml/base.js';
Expand All @@ -18,12 +21,13 @@ export {
scale,
fieldToShiftedScalar,
field3ToShiftedScalar,
scaleShiftedSplit5,
scaleShifted,
add,
ShiftedScalar,
};

type Point = { x: Field; y: Field };
type ShiftedScalar = { low5: TupleN<Bool, 5>; high250: Field };
type ShiftedScalar = { lowBit: Bool; high254: Field };

/**
* Gadget to scale a point by a scalar, where the scalar is represented as a _native_ Field.
Expand All @@ -42,7 +46,7 @@ function scale(P: Point, s: Field): Point {
let t = fieldToShiftedScalar(s);

// return (t + 2^254)*P = (s - 2^254 + 2^254)*P = s*P
return scaleShiftedSplit5(P, t);
return scaleShifted(P, t);
}

/**
Expand All @@ -52,101 +56,102 @@ function scale(P: Point, s: Field): Point {
* This is the representation we use for scalars, since it can be used as input to `scaleShiftedSplit5()`.
*/
function fieldToShiftedScalar(s: Field): ShiftedScalar {
return field3ToShiftedScalar(fieldToField3(s));
let sBig = fieldToField3(s);

// assert that sBig is canonical mod p, so that we can't add (kp mod q) factors by doing things modulo q
ForeignField.assertLessThan(sBig, Fp.modulus);

return field3ToShiftedScalar(sBig);
}

/**
* Converts a 3-limb bigint to a shifted representation t = s = 2^254 mod q,
* where t is represented as a 5-bit low part and a 250-bit high part.
* Converts a 3-limb bigint to a shifted representation t = s - 2^255 mod q,
* where t is represented as a low bit and a 254-bit high part.
*/
function field3ToShiftedScalar(s: Field3): ShiftedScalar {
// constant case
if (Field3.isConstant(s)) {
let t = Fq.mod(Field3.toBigint(s) - (1n << 254n));
let low5 = createField(t & 0x1fn).toBits(5);
let high250 = createField(t >> 5n);
return { low5: TupleN.fromArray(5, low5), high250 };
let t = Fq.mod(Field3.toBigint(s) - (1n << 255n));
let lowBit = createBool((t & 1n) === 1n);
let high254 = createField(t >> 1n);
return { lowBit, high254 };
}

// compute t = s - 2^254 mod q using foreign field subtraction
let twoTo254 = Field3.from(1n << 254n);
let [t0, t1, t2] = ForeignField.sub(s, twoTo254, Fq.modulus);
// compute t = s - 2^255 mod q using foreign field subtraction
let twoTo255 = Field3.from(Fq.mod(1n << 255n));
let t = ForeignField.sub(s, twoTo255, Fq.modulus);

// split t into 250 high bits and 5 low bits
// => split t0 into [5, 83]
let tLo = exists(5, () => {
let t = t0.toBigInt();
return [bit(t, 0), bit(t, 1), bit(t, 2), bit(t, 3), bit(t, 4)];
// it's necessary to prove that t is canonical -- otherwise its bit representation is ambiguous
ForeignField.assertLessThan(t, Fq.modulus);

let [t0, t1, t2] = t;

// split t into 254 high bits and a low bit
// => split t0 into [1, 87]
let [tLo, tHi0] = exists(2, () => {
let t0_ = t0.toBigInt();
return [bit(t0_, 0), t0_ >> 1n];
});
let tLoBools = TupleN.map(tLo, (x) => x.assertBool());
let tHi0 = existsOne(() => t0.toBigInt() >> 5n);
let tLoBool = tLo.assertBool();

// prove split
// since we know that t0 < 2^88, this proves that t0High < 2^83
packBits(tLo)
.add(tHi0.mul(1n << 5n))
.assertEquals(t0);
// since we know that t0 < 2^88, this proves that t0High < 2^87
tLo.add(tHi0.mul(2n)).assertEquals(t0);

// pack tHi
// proves that tHi is in [0, 2^250)
let tHi = tHi0
.add(t1.mul(1n << (l - 5n)))
.add(t2.mul(1n << (2n * l - 5n)))
.add(t1.mul(1n << (l - 1n)))
.add(t2.mul(1n << (2n * l - 1n)))
.seal();

return { low5: tLoBools, high250: tHi };
return { lowBit: tLoBool, high254: tHi };
}

/**
* Internal helper to compute `(t + 2^254)*P`.
* `t` is expected to be split into 250 high bits (t >> 5) and 5 low bits (t & 0x1f).
* Internal helper to compute `(t + 2^255)*P`.
* `t` is expected to be split into 254 high bits (t >> 1) and a low bit (t & 1).
*
* The gadget proves that `tHi` is in [0, 2^254) but assumes that `tLo` consists of bits.
*
* The gadget proves that `tHi` is in [0, 2^250) but assumes that `tLo` consists of bits.
* Optionally, you can specify a different number of high bits by passing in `numHighBits`.
*/
function scaleShiftedSplit5(
function scaleShifted(
{ x, y }: Point,
{ low5: tLo, high250: tHi }: ShiftedScalar
{ lowBit: tLo, high254: tHi }: ShiftedScalar,
numHighBits = 254
): Point {
// constant case
if (isConstant(x, y, tHi, ...tLo)) {
if (isConstant(x, y, tHi, tLo)) {
let sP = PallasAffine.scale(
PallasAffine.fromNonzero({ x: x.toBigInt(), y: y.toBigInt() }),
Fq.add(packBits(tLo).toBigInt() + (tHi.toBigInt() << 5n), 1n << 254n)
Fq.mod(tLo.toField().toBigInt() + 2n * tHi.toBigInt() + (1n << 255n))
);
return { x: createField(sP.x), y: createField(sP.y) };
}
const Field = getField();
const Point = provable({ x: Field, y: Field });
const zero = createField(0n);

// R = (2*(t >> 5) + 1 + 2^250)P
// also proves that tHi is in [0, 2^250)
let [, RMl] = Snarky.group.scaleFastUnpack(
// R = (2*(t >> 1) + 1 + 2^255)P
// also returns a 255-bit representation of tHi
let [, RMl, [, ...tHiBitsMl]] = Snarky.group.scaleFastUnpack(
[0, x.value, y.value],
[0, tHi.value],
250
255
);
let P = { x, y };
let R = { x: createField(RMl[1]), y: createField(RMl[2]) };
let [t0, t1, t2, t3, t4] = tLo;

// R = t4 ? R : R - P = ((t >> 4) + 2^250)P
R = Provable.if(t4, Point, R, addNonZero(R, negate(P)));

// R = ((t >> 3) + 2^251)P
// R = ((t >> 2) + 2^252)P
// R = ((t >> 1) + 2^253)P
for (let t of [t3, t2, t1]) {
R = addNonZero(R, R);
R = Provable.if(t, Point, addNonZero(R, P), R);
// prove that tHi has only `numHighBits` bits set
for (let i = numHighBits; i < 255; i++) {
createField(tHiBitsMl[i]).assertEquals(zero);
}

// R = (t + 2^254)P
// in the final step, we allow a zero output to make it work for the 0 scalar
R = addNonZero(R, R);
let { result, isInfinity } = add(R, P);
result = Provable.if(isInfinity, Point, { x: zero, y: zero }, result);
R = Provable.if(t0, Point, result, R);
// R = tLo ? R : R - P = (t + 2^255)P
// we also handle a zero R-P result to make scaling work for the 0 scalar
let { result, isInfinity } = add(R, negate(P));
let RmP = Provable.if(isInfinity, Point, { x: zero, y: zero }, result);
R = Provable.if(tLo, Point, R, RmP);

return R;
}
Expand Down Expand Up @@ -194,15 +199,6 @@ function add(g: Point, h: Point): { result: Point; isInfinity: Bool } {
return { result: { x: x3, y: y3 }, isInfinity };
}

/**
* Addition that asserts the result is non-zero.
*/
function addNonZero(g: Point, h: Point) {
let { result, isInfinity } = add(g, h);
isInfinity.assertFalse();
return result;
}

/**
* Negates a point.
*/
Expand Down
4 changes: 2 additions & 2 deletions src/lib/provable/group.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
import { Provable } from './provable.js';
import { Bool } from './bool.js';
import { assert } from '../util/assert.js';
import { add, scaleShiftedSplit5 } from './gadgets/native-curve.js';
import { add, scaleShifted } from './gadgets/native-curve.js';

export { Group };

Expand Down Expand Up @@ -180,7 +180,7 @@ class Group {
let g_proj = Pallas.scale(toProjective(this), scalar.toBigInt());
return fromProjective(g_proj);
} else {
let result = scaleShiftedSplit5(this, scalar);
let result = scaleShifted(this, scalar);
return new Group(result);
}
}
Expand Down
70 changes: 34 additions & 36 deletions src/lib/provable/scalar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ import { Scalar as SignableFq } from '../../mina-signer/src/curve-bigint.js';
import { Field, checkBitLength } from './field.js';
import { FieldVar } from './core/fieldvar.js';
import { Bool } from './bool.js';
import { TupleN } from '../util/types.js';
import {
ShiftedScalar,
field3ToShiftedScalar,
fieldToShiftedScalar,
} from './gadgets/native-curve.js';
import { isConstant, packBits } from './gadgets/common.js';
import { isConstant } from './gadgets/common.js';
import { Provable } from './provable.js';
import { assert } from '../util/assert.js';
import type { HashInput } from './types/provable-derivers.js';
Expand All @@ -21,20 +21,20 @@ type ScalarConst = [0, bigint];
/**
* Represents a {@link Scalar}.
*/
class Scalar {
class Scalar implements ShiftedScalar {
/**
* We represent a scalar s in shifted form `t = s - 2^254 mod q,
* split into its low 5 bits (t & 0x1f) and high 250 bits (t >> 5).
* The reason is that we can efficiently compute the scalar multiplication `(t + 2^254) * P = s * P`.
* We represent a scalar s in shifted form `t = s - 2^255 mod q,
* split into its low bit (t & 1) and high 254 bits (t >> 1).
* The reason is that we can efficiently compute the scalar multiplication `(t + 2^255) * P = s * P`.
*/
low5: TupleN<Bool, 5>;
high250: Field;
lowBit: Bool;
high254: Field;

static ORDER = Fq.modulus;

private constructor(low5: TupleN<Bool, 5>, high250: Field) {
this.low5 = low5;
this.high250 = high250;
private constructor(lowBit: Bool, high254: Field) {
this.lowBit = lowBit;
this.high254 = high254;
}

/**
Expand All @@ -44,10 +44,10 @@ class Scalar {
*/
static from(s: Scalar | bigint | number | string): Scalar {
if (s instanceof Scalar) return s;
let t = Fq.mod(BigInt(s) - (1n << 254n));
let low5 = new Field(t & 0x1fn).toBits(5);
let high250 = new Field(t >> 5n);
return new Scalar(TupleN.fromArray(5, low5), high250);
let t = Fq.mod(BigInt(s) - (1n << 255n));
let lowBit = new Bool((t & 1n) === 1n);
let high254 = new Field(t >> 1n);
return new Scalar(lowBit, high254);
}

/**
Expand All @@ -56,17 +56,17 @@ class Scalar {
* This is always possible and unambiguous, since the scalar field is larger than the base field.
*/
static fromNativeField(s: Field): Scalar {
let { low5, high250 } = fieldToShiftedScalar(s);
return new Scalar(low5, high250);
let { lowBit, high254 } = fieldToShiftedScalar(s);
return new Scalar(lowBit, high254);
}

/**
* Check whether this {@link Scalar} is a hard-coded constant in the constraint system.
* If a {@link Scalar} is constructed outside provable code, it is a constant.
*/
isConstant() {
let { low5, high250 } = this;
return isConstant(high250, ...low5);
let { lowBit, high254 } = this;
return isConstant(lowBit, high254);
}

/**
Expand All @@ -85,11 +85,9 @@ class Scalar {
* Convert this {@link Scalar} into a bigint
*/
toBigInt() {
let { low5, high250 } = this.toConstant();
return Fq.add(
packBits(low5).toBigInt() + (high250.toBigInt() << 5n),
1n << 254n
);
let { lowBit, high254 } = this.toConstant();
let t = lowBit.toField().toBigInt() + 2n * high254.toBigInt();
return Fq.mod(t + (1n << 255n));
}

/**
Expand All @@ -104,8 +102,8 @@ class Scalar {
let sBig = field3FromBits(bits);

// convert to shifted representation
let { low5, high250 } = field3ToShiftedScalar(sBig);
return new Scalar(low5, high250);
let { lowBit, high254 } = field3ToShiftedScalar(sBig);
return new Scalar(lowBit, high254);
}

/**
Expand Down Expand Up @@ -211,7 +209,7 @@ class Scalar {
* The fields are not constrained to be boolean.
*/
static toFields(x: Scalar) {
return [...x.low5.map((b) => b.toField()), x.high250];
return [x.lowBit.toField(), x.high254];
}

/**
Expand Down Expand Up @@ -239,7 +237,7 @@ class Scalar {
*
*/
static toInput(x: Scalar): HashInput {
return { fields: [x.high250], packed: x.low5.map((f) => [f.toField(), 1]) };
return { fields: [x.high254], packed: [[x.lowBit.toField(), 1]] };
}

/**
Expand All @@ -258,12 +256,12 @@ class Scalar {
*/
static fromFields(fields: Field[]): Scalar {
assert(
fields.length === 6,
`Scalar.fromFields(): expected 6 fields, got ${fields.length}`
fields.length === 2,
`Scalar.fromFields(): expected 2 fields, got ${fields.length}`
);
let low5 = fields.slice(0, 5).map(Bool.Unsafe.fromField);
let high250 = fields[5];
return new Scalar(TupleN.fromArray(5, low5), high250);
let lowBit = Bool.Unsafe.fromField(fields[0]);
let high254 = fields[1];
return new Scalar(lowBit, high254);
}

/**
Expand All @@ -272,18 +270,18 @@ class Scalar {
* Returns the size of this type in {@link Field} elements.
*/
static sizeInFields(): number {
return 6;
return 2;
}

/**
* Part of the {@link Provable} interface.
*/
static check(s: Scalar) {
/**
* It is not necessary to constrain the range of high250, because the only provable operation on Scalar
* It is not necessary to constrain the range of high254, because the only provable operation on Scalar
* which relies on that range is scalar multiplication -- which constrains the range itself.
*/
return s.low5.forEach(Bool.check);
return Bool.check(s.lowBit);
}

// ProvableExtended<Scalar>
Expand Down

0 comments on commit 04f8648

Please sign in to comment.