diff --git a/__tests__/utils/__snapshots__/ellipticalCurve.test.ts.snap b/__tests__/utils/__snapshots__/ellipticalCurve.test.ts.snap index c540f5f84..74ce57b39 100644 --- a/__tests__/utils/__snapshots__/ellipticalCurve.test.ts.snap +++ b/__tests__/utils/__snapshots__/ellipticalCurve.test.ts.snap @@ -1,3 +1,5 @@ // Jest Snapshot v1, https://goo.gl/fbAQLP exports[`pedersen() 1`] = `"0x5ed2703dfdb505c587700ce2ebfcab5b3515cd7e6114817e6026ec9d4b364ca"`; + +exports[`pedersen() with 0 1`] = `"0x1a5c561f97b52c17a19f34c4499a745cd4e8412e29e4ed5e91e4481c7d53935"`; diff --git a/__tests__/utils/ellipticalCurve.test.ts b/__tests__/utils/ellipticalCurve.test.ts index 21ed1339c..2bba8d4ec 100644 --- a/__tests__/utils/ellipticalCurve.test.ts +++ b/__tests__/utils/ellipticalCurve.test.ts @@ -25,6 +25,11 @@ test('pedersen()', () => { expect(own).toMatchSnapshot(); }); +test('pedersen() with 0', () => { + const own = pedersen(['0x12773', '0x0']); + expect(own).toMatchSnapshot(); +}); + test('computeHashOnElements()', () => { const array = ['1', '2', '3', '4']; expect(computeHashOnElements(array)).toBe( diff --git a/src/account/session.ts b/src/account/session.ts index 09e22be0e..2e1b2bfca 100644 --- a/src/account/session.ts +++ b/src/account/session.ts @@ -1,3 +1,5 @@ +import assert from 'minimalistic-assert'; + import { ZERO } from '../constants'; import { ProviderInterface, ProviderOptions } from '../provider'; import { SignerInterface } from '../signer'; @@ -12,8 +14,9 @@ import { KeyPair, } from '../types'; import { feeTransactionVersion, transactionVersion } from '../utils/hash'; +import { MerkleTree } from '../utils/merkle'; import { BigNumberish, toBN } from '../utils/number'; -import type { SignedSession } from '../utils/session'; +import { SignedSession, createMerkleTreeForPolicies } from '../utils/session'; import { compileCalldata, estimatedFeeToMaxFee } from '../utils/stark'; import { fromCallsToExecuteCalldataWithNonce } from '../utils/transaction'; import { Account } from './default'; @@ -23,6 +26,8 @@ const SESSION_PLUGIN_CLASS_HASH = '0x6a184757e350de1fe3a544037efbef6434724980a572f294c90555dadc20052'; export class SessionAccount extends Account implements AccountInterface { + public merkleTree: MerkleTree; + constructor( providerOrOptions: ProviderOptions | ProviderInterface, address: string, @@ -30,6 +35,8 @@ export class SessionAccount extends Account implements AccountInterface { public signedSession: SignedSession ) { super(providerOrOptions, address, keyPairOrSigner); + this.merkleTree = createMerkleTreeForPolicies(signedSession.policies); + assert(signedSession.root === this.merkleTree.root, 'Invalid session'); } private async sessionToCall(session: SignedSession): Promise { diff --git a/src/utils/hash.ts b/src/utils/hash.ts index 3c6041371..c9f6dab69 100644 --- a/src/utils/hash.ts +++ b/src/utils/hash.ts @@ -62,13 +62,15 @@ export function pedersen(input: [BigNumberish, BigNumberish]) { for (let i = 0; i < input.length; i += 1) { let x = toBN(input[i]); assert(x.gte(ZERO) && x.lt(toBN(addHexPrefix(FIELD_PRIME))), `Invalid input: ${input[i]}`); - for (let j = 0; j < 252; j += 1) { - const pt = constantPoints[2 + i * 252 + j]; - assert(!point.getX().eq(pt.getX())); - if (x.and(ONE).toNumber() !== 0) { - point = point.add(pt); + if (!x.isZero()) { + for (let j = 0; j < 252; j += 1) { + const pt = constantPoints[2 + i * 252 + j]; + assert(!point.getX().eq(pt.getX())); + if (x.and(ONE).toNumber() !== 0) { + point = point.add(pt); + } + x = x.shrn(1); } - x = x.shrn(1); } } return addHexPrefix(point.getX().toString(16)); diff --git a/src/utils/session.ts b/src/utils/session.ts index c756a8983..baebc22ce 100644 --- a/src/utils/session.ts +++ b/src/utils/session.ts @@ -9,16 +9,13 @@ interface Policy { selector: string; } -interface BaseSession { +export interface RequestSession { key: string; expires: number; -} - -export interface RequestSession extends BaseSession { policies: Policy[]; } -export interface PreparedSession extends BaseSession { +export interface PreparedSession extends RequestSession { root: string; } @@ -30,10 +27,13 @@ function preparePolicy({ contractAddress, selector }: Policy): string { return pedersen([contractAddress, prepareSelector(selector)]); } +export function createMerkleTreeForPolicies(policies: Policy[]): MerkleTree { + return new MerkleTree(policies.map(preparePolicy)); +} + export function prepareSession(session: RequestSession): PreparedSession { - const { policies, ...rest } = session; - const { root } = new MerkleTree(policies.map(preparePolicy)); - return { ...rest, root }; + const { root } = createMerkleTreeForPolicies(session.policies); + return { ...session, root }; } export async function createSession( @@ -41,7 +41,7 @@ export async function createSession( account: AccountInterface, domain: StarkNetDomain = {} ): Promise { - const { key, expires, root } = prepareSession(session); + const { expires, key, policies, root } = prepareSession(session); const signature = await account.signMessage({ primaryType: 'Session', types: { @@ -69,6 +69,7 @@ export async function createSession( }); return { key, + policies, expires, root, signature,