Skip to content

Commit

Permalink
Fixed bugs with async state
Browse files Browse the repository at this point in the history
  • Loading branch information
rpanic committed May 16, 2024
1 parent af5586f commit e2feb08
Show file tree
Hide file tree
Showing 44 changed files with 384 additions and 650 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"scripts": {
"dev": "npx lerna run dev",
"build": "npx lerna run build",
"lint": "npx lerna run lint",
"lint": "npx lerna run lint --parallel",
"lint:staged": "eslint",
"test": "npx lerna run test -- --passWithNoTests",
"test:ci": "npx lerna run test -- --passWithNoTests --forceExit",
Expand Down
4 changes: 2 additions & 2 deletions packages/cli/test/chain.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ export class Balances extends RuntimeModule<object> {

@runtimeMethod()
public async getBalance(address: PublicKey): Promise<Option<UInt64>> {
return this.balances.get(address);
return await this.balances.get(address);
}

@runtimeMethod()
public async setBalance(address: PublicKey, balance: UInt64) {
this.balances.set(address, balance);
await this.balances.set(address, balance);
}
}

Expand Down
8 changes: 4 additions & 4 deletions packages/library/src/hooks/TransactionFeeHook.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ interface Balances {
from: PublicKey,
to: PublicKey,
amount: Balance
) => void;
) => Promise<void>;
}

export interface TransactionFeeHookConfig
Expand Down Expand Up @@ -95,8 +95,8 @@ export class TransactionFeeHook extends ProvableTransactionHook<TransactionFeeHo
return this.persistedFeeAnalyzer;
}

public transferFee(from: PublicKeyOption, fee: UInt64) {
this.balances.transfer(
public async transferFee(from: PublicKeyOption, fee: UInt64) {
await this.balances.transfer(
new TokenId(this.config.tokenId),
from.value,
PublicKey.fromBase58(this.config.feeRecipient),
Expand Down Expand Up @@ -144,7 +144,7 @@ export class TransactionFeeHook extends ProvableTransactionHook<TransactionFeeHo
)
);

this.transferFee(
await this.transferFee(
executionData.transaction.sender,
UInt64.Unsafe.fromField(fee.value)
);
Expand Down
42 changes: 25 additions & 17 deletions packages/library/src/runtime/Balances.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
runtimeModule,
} from "@proto-kit/module";
import { StateMap, assert } from "@proto-kit/protocol";
import { Field, PublicKey, Struct } from "o1js";
import { Field, PublicKey, Struct, Provable } from "o1js";

import { UInt64 } from "../math/UInt64";

Expand Down Expand Up @@ -51,25 +51,32 @@ export class Balances<Config = NoConfig>
Balance
);

public getBalance(tokenId: TokenId, address: PublicKey): Balance {
public async getBalance(
tokenId: TokenId,
address: PublicKey
): Promise<Balance> {
const key = new BalancesKey({ tokenId, address });
const balanceOption = this.balances.get(key);
const balanceOption = await this.balances.get(key);
return Balance.Unsafe.fromField(balanceOption.value.value);
}

public setBalance(tokenId: TokenId, address: PublicKey, amount: Balance) {
public async setBalance(
tokenId: TokenId,
address: PublicKey,
amount: Balance
) {
const key = new BalancesKey({ tokenId, address });
this.balances.set(key, amount);
await this.balances.set(key, amount);
}

public transfer(
public async transfer(
tokenId: TokenId,
from: PublicKey,
to: PublicKey,
amount: Balance
) {
const fromBalance = this.getBalance(tokenId, from);
const toBalance = this.getBalance(tokenId, to);
const fromBalance = await this.getBalance(tokenId, from);
const toBalance = await this.getBalance(tokenId, to);

const fromBalanceIsSufficient = fromBalance.greaterThanOrEqual(amount);

Expand All @@ -78,20 +85,21 @@ export class Balances<Config = NoConfig>
const newFromBalance = fromBalance.sub(amount);
const newToBalance = toBalance.add(amount);

this.setBalance(tokenId, from, newFromBalance);
this.setBalance(tokenId, to, newToBalance);
await this.setBalance(tokenId, from, newFromBalance);
await this.setBalance(tokenId, to, newToBalance);
}

public mint(tokenId: TokenId, address: PublicKey, amount: Balance) {
const balance = this.getBalance(tokenId, address);
public async mint(tokenId: TokenId, address: PublicKey, amount: Balance) {
const balance = await this.getBalance(tokenId, address);
const newBalance = balance.add(amount);
this.setBalance(tokenId, address, newBalance);
await this.setBalance(tokenId, address, newBalance);
}

public burn(tokenId: TokenId, address: PublicKey, amount: Balance) {
const balance = this.getBalance(tokenId, address);
public async burn(tokenId: TokenId, address: PublicKey, amount: Balance) {
const balance = await this.getBalance(tokenId, address);
Provable.log("Balance", balance, amount);
const newBalance = balance.sub(amount);
this.setBalance(tokenId, address, newBalance);
await this.setBalance(tokenId, address, newBalance);
}

@runtimeMethod()
Expand All @@ -103,6 +111,6 @@ export class Balances<Config = NoConfig>
) {
assert(this.transaction.sender.value.equals(from), errors.senderNotFrom());

this.transfer(tokenId, from, to, amount);
await this.transfer(tokenId, from, to, amount);
}
}
34 changes: 34 additions & 0 deletions packages/library/test/math/State.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import "reflect-metadata";
import {
State,
StateServiceProvider,
RuntimeMethodExecutionContext,
RuntimeTransaction,
NetworkState,
} from "@proto-kit/protocol";
import { UInt64, Field } from "o1js";
import { InMemoryStateService } from "@proto-kit/module";
import { container } from "tsyringe";

describe("interop uint <-> state", () => {
it("should deserialize as a correct class instance coming from state", async () => {
const state = new State<UInt64>(UInt64);
const service = new InMemoryStateService();
const provider = new StateServiceProvider();
state.path = Field(0);
state.stateServiceProvider = provider;
provider.setCurrentStateService(service);

await service.set(state.path, [Field(10)]);

const context = container.resolve(RuntimeMethodExecutionContext);
context.setup({
transaction: RuntimeTransaction.dummyTransaction(),
networkState: NetworkState.empty(),
});

const uint = await state.get();
const uint2 = uint.value.add(5);
expect(uint2.toString()).toStrictEqual("15");
});
});
2 changes: 1 addition & 1 deletion packages/module/src/method/runtimeMethod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ function runtimeMethodInternal(options: {

let result: unknown;
try {
result = Reflect.apply(simulatedMethod, this, args);
result = await Reflect.apply(simulatedMethod, this, args);
} finally {
executionContext.afterMethod();
}
Expand Down
8 changes: 4 additions & 4 deletions packages/module/src/state/InMemoryStateService.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import { Field } from "o1js";
import { StateService } from "@proto-kit/protocol";
import { SimpleAsyncStateService } from "@proto-kit/protocol";

/**
* Naive implementation of an in-memory variant of the StateService interface
*/
export class InMemoryStateService implements StateService {
export class InMemoryStateService implements SimpleAsyncStateService {
/**
* This mapping container null values if the specific entry has been deleted.
* This is used by the CachedState service to keep track of deletions
*/
public values: Record<string, Field[] | null> = {};

public get(key: Field): Field[] | undefined {
public async get(key: Field): Promise<Field[] | undefined> {
return this.values[key.toString()] ?? undefined;
}

public set(key: Field, value: Field[] | undefined) {
public async set(key: Field, value: Field[] | undefined) {
if (value === undefined) {
this.values[key.toString()] = null;
} else {
Expand Down
34 changes: 17 additions & 17 deletions packages/module/test/modules/Balances.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
type ProvableStateTransition,
Path,
MethodPublicOutput,
StateService,
SimpleAsyncStateService,
RuntimeMethodExecutionContext,
RuntimeTransaction,
NetworkState,
Expand All @@ -20,19 +20,19 @@ import { Admin } from "./Admin.js";
describe("balances", () => {
let balances: Balances;

let state: StateService;
let state: SimpleAsyncStateService;

let runtime: Runtime<{
Admin: typeof Admin;
Balances: typeof Balances;
}>;

function getStateValue(path: Field | undefined) {
async function getStateValue(path: Field | undefined) {
if (!path) {
throw new Error("Path not found");
}

const stateValue = state.get(path);
const stateValue = await state.get(path);

if (!stateValue) {
throw new Error("stateValue is undefined");
Expand Down Expand Up @@ -82,7 +82,7 @@ describe("balances", () => {

await runtime.zkProgrammable.zkProgram.compile();

balances.getTotalSupply();
await balances.getTotalSupply();

const { result } = executionContext.current();

Expand All @@ -107,15 +107,15 @@ describe("balances", () => {
describe("state transitions", () => {
let stateTransitions: ProvableStateTransition[];

beforeEach(() => {
beforeEach(async () => {
const executionContext = container.resolve(
RuntimeMethodExecutionContext
);
executionContext.setup({
transaction: RuntimeTransaction.dummyTransaction(),
networkState: NetworkState.empty(),
});
balances.getTotalSupply();
await balances.getTotalSupply();

stateTransitions = executionContext
.current()
Expand All @@ -139,13 +139,13 @@ describe("balances", () => {
);
});

it("should produce a from-only state transition", () => {
it("should produce a from-only state transition", async () => {
expect.assertions(3);

const [stateTransition] = stateTransitions;

const value = UInt64.fromFields(
getStateValue(balances.totalSupply.path)
await getStateValue(balances.totalSupply.path)
);
const treeValue = Poseidon.hash(value.toFields());

Expand All @@ -166,7 +166,7 @@ describe("balances", () => {
state.set(balances.totalSupply.path!, undefined);
});

beforeEach(() => {
beforeEach(async () => {
const executionContext = container.resolve(
RuntimeMethodExecutionContext
);
Expand All @@ -175,7 +175,7 @@ describe("balances", () => {
networkState: NetworkState.empty(),
});

balances.getTotalSupply();
await balances.getTotalSupply();

stateTransitions = executionContext
.current()
Expand Down Expand Up @@ -221,7 +221,7 @@ describe("balances", () => {
describe("state transitions", () => {
let stateTransitions: ProvableStateTransition[];

beforeEach(() => {
beforeEach(async () => {
const executionContext = container.resolve(
RuntimeMethodExecutionContext
);
Expand All @@ -230,7 +230,7 @@ describe("balances", () => {
networkState: NetworkState.empty(),
});

balances.setTotalSupply();
await balances.setTotalSupply();

stateTransitions = executionContext
.current()
Expand All @@ -254,12 +254,12 @@ describe("balances", () => {
);
});

it("should produce a from-to state transition", () => {
it("should produce a from-to state transition", async () => {
expect.assertions(4);

const [stateTransition] = stateTransitions;
const fromValue = UInt64.fromFields(
getStateValue(balances.totalSupply.path)
await getStateValue(balances.totalSupply.path)
);
const fromTreeValue = Poseidon.hash(fromValue.toFields());

Expand All @@ -286,7 +286,7 @@ describe("balances", () => {
let stateTransitions: ProvableStateTransition[];
const address = PrivateKey.random().toPublicKey();

beforeEach(() => {
beforeEach(async () => {
const executionContext = container.resolve(
RuntimeMethodExecutionContext
);
Expand All @@ -295,7 +295,7 @@ describe("balances", () => {
networkState: NetworkState.empty(),
});

balances.getBalance(address);
await balances.getBalance(address);

stateTransitions = executionContext
.current()
Expand Down
14 changes: 7 additions & 7 deletions packages/module/test/modules/Balances.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,26 @@ export class Balances extends RuntimeModule<BalancesConfig> {

@runtimeMethod()
public async getTotalSupply() {
this.totalSupply.get();
await this.totalSupply.get();
}

@runtimeMethod()
public async setTotalSupply() {
this.totalSupply.set(UInt64.from(20));
await this.totalSupply.set(UInt64.from(20));
await this.admin.isAdmin(this.transaction.sender.value);
}

@runtimeMethod()
public async getBalance(address: PublicKey) {
this.balances.get(address).orElse(UInt64.zero);
(await this.balances.get(address)).orElse(UInt64.zero);
}

@runtimeMethod()
public async transientState() {
const totalSupply = this.totalSupply.get();
this.totalSupply.set(totalSupply.orElse(UInt64.zero).add(100));
const totalSupply = await this.totalSupply.get();
await this.totalSupply.set(totalSupply.orElse(UInt64.zero).add(100));

const totalSupply2 = this.totalSupply.get();
this.totalSupply.set(totalSupply2.orElse(UInt64.zero).add(100));
const totalSupply2 = await this.totalSupply.get();
await this.totalSupply.set(totalSupply2.orElse(UInt64.zero).add(100));
}
}
Loading

0 comments on commit e2feb08

Please sign in to comment.