Skip to content

Commit

Permalink
Fixed issues with UInts when using them with protokit state
Browse files Browse the repository at this point in the history
  • Loading branch information
rpanic committed Jan 25, 2024
1 parent eb58d2a commit 24a6b3d
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 33 deletions.
35 changes: 18 additions & 17 deletions packages/library/src/math/UInt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ const errors = {
};

export abstract class UIntX<This extends UIntX<any>> extends Struct({
value: Field,
NUM_BITS: Number,
value: Field
}) {
public abstract numBits(): number;

protected static readonly assertionFunction: (
bool: Bool,
msg?: string
Expand Down Expand Up @@ -46,14 +47,14 @@ export abstract class UIntX<This extends UIntX<any>> extends Struct({

protected constructor(
value: Field,
bits: number,
private readonly impls: {
creator: (value: Field) => This;
from: (value: Field | This | bigint | number | string) => This;
}
) {
super({ value, NUM_BITS: bits });
super({ value });

const bits = this.numBits();
if (bits % 16 !== 0) {
throw errors.canOnlyCreateMultiplesOf16Bits();
}
Expand Down Expand Up @@ -107,11 +108,11 @@ export abstract class UIntX<This extends UIntX<any>> extends Struct({
);

UIntX.assertionFunction(
q.rangeCheckHelper(this.NUM_BITS).equals(q),
q.rangeCheckHelper(this.numBits()).equals(q),
"Divison overflowing"
);

if (this.NUM_BITS * 2 > 255) {
if (this.numBits() * 2 > 255) {
// Prevents overflows over the finite field boundary for applicable uints
divisor_.assertLessThan(x, "Divisor too large");
q.assertLessThan(x, "Quotient too large");
Expand All @@ -122,7 +123,7 @@ export abstract class UIntX<This extends UIntX<any>> extends Struct({
let r = x.sub(q.mul(divisor_)).seal();

UIntX.assertionFunction(
r.rangeCheckHelper(this.NUM_BITS).equals(r),
r.rangeCheckHelper(this.numBits()).equals(r),
"Divison overflowing, remainder"
);

Expand Down Expand Up @@ -183,7 +184,7 @@ export abstract class UIntX<This extends UIntX<any>> extends Struct({

// Sqrt fits into (NUM_BITS / 2) bits
sqrtField
.rangeCheckHelper(this.NUM_BITS)
.rangeCheckHelper(this.numBits())
.assertEquals(sqrtField, "Sqrt output overflowing");

// Range check included here?
Expand All @@ -195,12 +196,12 @@ export abstract class UIntX<This extends UIntX<any>> extends Struct({
});

rest
.rangeCheckHelper(this.NUM_BITS)
.rangeCheckHelper(this.numBits())
.assertEquals(rest, "Sqrt rest output overflowing");

const square = sqrtField.mul(sqrtField);

if (this.NUM_BITS * 2 > 255) {
if (this.numBits() * 2 > 255) {
square.assertGreaterThan(sqrtField, "Sqrt result overflowing");
}

Expand Down Expand Up @@ -247,13 +248,13 @@ export abstract class UIntX<This extends UIntX<any>> extends Struct({
let yField = this.impls.from(y).value;
let z = this.value.mul(yField);

if (this.NUM_BITS * 2 > 255) {
if (this.numBits() * 2 > 255) {
// Only one should be enough
z.assertGreaterThan(this.value, "Multiplication overflowing");
}

UIntX.assertionFunction(
z.rangeCheckHelper(this.NUM_BITS).equals(z),
z.rangeCheckHelper(this.numBits()).equals(z),
"Multiplication overflowing"
);
return this.impls.creator(z);
Expand All @@ -265,7 +266,7 @@ export abstract class UIntX<This extends UIntX<any>> extends Struct({
public add(y: This | bigint | number) {
let z = this.value.add(this.impls.from(y).value);
UIntX.assertionFunction(
z.rangeCheckHelper(this.NUM_BITS).equals(z),
z.rangeCheckHelper(this.numBits()).equals(z),
"Addition overflowing"
);
return this.impls.creator(z);
Expand All @@ -277,7 +278,7 @@ export abstract class UIntX<This extends UIntX<any>> extends Struct({
public sub(y: This | bigint | number) {
let z = this.value.sub(this.impls.from(y).value);
UIntX.assertionFunction(
z.rangeCheckHelper(this.NUM_BITS).equals(z),
z.rangeCheckHelper(this.numBits()).equals(z),
"Subtraction overflow"
);
return this.impls.creator(z);
Expand All @@ -292,8 +293,8 @@ export abstract class UIntX<This extends UIntX<any>> extends Struct({
}
let xMinusY = this.value.sub(y.value).seal();
let yMinusX = xMinusY.neg();
let xMinusYFits = xMinusY.rangeCheckHelper(this.NUM_BITS).equals(xMinusY);
let yMinusXFits = yMinusX.rangeCheckHelper(this.NUM_BITS).equals(yMinusX);
let yMinusXFits = yMinusX.rangeCheckHelper(this.numBits()).equals(yMinusX);
let xMinusYFits = xMinusY.rangeCheckHelper(this.numBits()).equals(xMinusY);
UIntX.assertionFunction(xMinusYFits.or(yMinusXFits));
// x <= y if y - x fits in 64 bits
return yMinusXFits;
Expand All @@ -316,7 +317,7 @@ export abstract class UIntX<This extends UIntX<any>> extends Struct({
}
let yMinusX = y.value.sub(this.value).seal();
UIntX.assertionFunction(
yMinusX.rangeCheckHelper(this.NUM_BITS).equals(yMinusX),
yMinusX.rangeCheckHelper(this.numBits()).equals(yMinusX),
message
);
}
Expand Down
13 changes: 10 additions & 3 deletions packages/library/src/math/UInt112.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@ import { Field, UInt32, UInt64 } from "o1js";
import { UIntX } from "./UInt";

export class UInt112 extends UIntX<UInt112> {
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
public static NUM_BITS = 112;
public static get NUM_BITS() {
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
return 112;
}

public numBits() {
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
return 112;
}

/**
* Static method to create a {@link UIntX} with value `0`.
Expand Down Expand Up @@ -39,7 +46,7 @@ export class UInt112 extends UIntX<UInt112> {
}

public constructor(value: Field) {
super(value, UInt112.NUM_BITS, {
super(value, {
creator: (x) => new UInt112(x),
from: (x) => UInt112.from(x),
});
Expand Down
11 changes: 9 additions & 2 deletions packages/library/src/math/UInt224.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@ import { Field, UInt32, UInt64 } from "o1js";
import { UInt112 } from "./UInt112";

export class UInt224 extends UIntX<UInt224> {
public static get NUM_BITS() {
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
public static NUM_BITS = 224;
return 224;
}

public numBits() {
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
return 224;
}

/**
* Static method to create a {@link UIntX} with value `0`.
Expand Down Expand Up @@ -39,7 +46,7 @@ export class UInt224 extends UIntX<UInt224> {
}

public constructor(value: Field) {
super(value, UInt224.NUM_BITS, {
super(value, {
creator: (x) => new UInt224(x),
from: (x) => UInt224.from(x),
});
Expand Down
13 changes: 10 additions & 3 deletions packages/library/src/math/UInt32.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@ import { Field, UInt32 as o1UInt32 } from "o1js";
import { UIntX } from "./UInt";

export class UInt32 extends UIntX<UInt32> {
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
public static NUM_BITS = 32;
public static get NUM_BITS() {
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
return 32;
}

public numBits() {
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
return 32;
}

/**
* Static method to create a {@link UIntX} with value `0`.
Expand Down Expand Up @@ -39,7 +46,7 @@ export class UInt32 extends UIntX<UInt32> {
}

public constructor(value: Field) {
super(value, UInt32.NUM_BITS, {
super(value, {
creator: (x) => new UInt32(x),
from: (x) => UInt32.from(x),
});
Expand Down
12 changes: 9 additions & 3 deletions packages/library/src/math/UInt64.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@ import { Field, UInt32, UInt64 as o1UInt64 } from "o1js";
import { UIntX } from "./UInt";

export class UInt64 extends UIntX<UInt64> {
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
public static NUM_BITS = 64;
public static get NUM_BITS() {
// eslint-disable-next-line @typescript-eslint/no-magic-numbers
return 64;
}

public numBits() {
return 64;
}

/**
* Static method to create a {@link UIntX} with value `0`.
Expand Down Expand Up @@ -39,7 +45,7 @@ export class UInt64 extends UIntX<UInt64> {
}

public constructor(value: Field) {
super(value, UInt64.NUM_BITS, {
super(value, {
creator: (x) => new UInt64(x),
from: (x) => UInt64.from(x),
});
Expand Down
19 changes: 14 additions & 5 deletions packages/library/test/math/UInt.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import "reflect-metadata";
import { container } from "tsyringe";
import { RuntimeMethodExecutionContext } from "@proto-kit/protocol";
import { RuntimeMethodExecutionContext, State } from "@proto-kit/protocol";
import { beforeEach } from "@jest/globals";
import bigintsqrt from "bigint-isqrt";
import { UInt112, UInt64 } from "../../src";
import { Provable } from "o1js";
import { Field, Provable } from "o1js";

describe("uint112", () => {
const executionContext = container.resolve(RuntimeMethodExecutionContext);
Expand Down Expand Up @@ -78,11 +78,20 @@ describe("uint112", () => {

const uint = Provable.witness(UInt64, () => UInt64.from(5));

const fields = UInt64.toFields(uint)
const fields = UInt64.toFields(uint);

expect(uint.NUM_BITS).toBe(64);
expect(uint.numBits()).toBe(64);
expect(uint.value.toBigInt()).toBe(5n);
expect(fields.length).toBe(1);
expect(fields[0].toBigInt()).toBe(5n);
})
});

it("should work for state", () => {
expect.assertions(1);

// Only a compilation test
const state = State.from(UInt64);

expect(1).toBe(1);
});
});

0 comments on commit 24a6b3d

Please sign in to comment.