diff --git a/contracts/TransformCustomSize.sol b/contracts/TransformCustomSize.sol new file mode 100644 index 00000000..b0ac35b0 --- /dev/null +++ b/contracts/TransformCustomSize.sol @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8; + +struct OneAndAHalfSlot { + uint256 x; + uint128 y; +} + +contract SizeDefault { + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable state-variable-assignment + uint immutable w1 = block.number; + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable state-variable-assignment + uint immutable w2 = block.timestamp; + uint immutable x; // slot 0 (after conversion to private) + uint constant y = 1; + uint224 z0; // slot 1 + uint256 z1; // slot 2 + uint32 z2; // slot 3 + OneAndAHalfSlot s1; // slot 4&5 + OneAndAHalfSlot s2; // slot 6&7 + uint32 z3; // slot 8 + uint32 z4; // slot 8 + uint32 z5; // slot 8 + uint64[5] a1; // slot 9&10 + uint64[3] a2; // slot 11 + + constructor(uint _x) { + x = _x; + } + // gap should be 38 = 50 - 12 +} + +/// @custom:storage-size 128 +contract SizeOverride { + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable state-variable-assignment + uint immutable w1 = block.number; + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable state-variable-assignment + uint immutable w2 = block.timestamp; + uint immutable x; // slot 0 (after conversion to private) + uint constant y = 1; + uint224 z0; // slot 1 + uint256 z1; // slot 2 + uint32 z2; // slot 3 + OneAndAHalfSlot s1; // slot 4&5 + OneAndAHalfSlot s2; // slot 6&7 + uint32 z3; // slot 8 + uint32 z4; // slot 8 + uint32 z5; // slot 8 + uint64[5] a1; // slot 9&10 + uint64[3] a2; // slot 11 + + constructor(uint _x) { + x = _x; + } + // gap should be 116 = 128 - 12 +} diff --git a/src/transform-0.8.test.ts b/src/transform-0.8.test.ts index eb35085c..e5617270 100644 --- a/src/transform-0.8.test.ts +++ b/src/transform-0.8.test.ts @@ -13,6 +13,7 @@ import { transformConstructor, } from './transformations/transform-constructor'; import { renameInheritdoc } from './transformations/rename-inheritdoc'; +import { addStorageGaps } from './transformations/add-storage-gaps'; const test = _test as TestFn; @@ -54,3 +55,13 @@ test('preserves immutable if allowed', t => { t.context.transform.apply(removeImmutable); t.snapshot(t.context.transform.results()[file]); }); + +test('custom contract size', t => { + const file = 'contracts/TransformCustomSize.sol'; + t.context.transform.apply(transformConstructor); + t.context.transform.apply(removeLeftoverConstructorHead); + t.context.transform.apply(removeStateVarInits); + t.context.transform.apply(removeImmutable); + t.context.transform.apply(addStorageGaps); + t.snapshot(t.context.transform.results()[file]); +}); diff --git a/src/transform-0.8.test.ts.md b/src/transform-0.8.test.ts.md index b5af85b7..1b7b33f2 100644 --- a/src/transform-0.8.test.ts.md +++ b/src/transform-0.8.test.ts.md @@ -102,3 +102,87 @@ Generated by [AVA](https://avajs.dev). }␊ }␊ ` + +## custom contract size + +> Snapshot 1 + + `// SPDX-License-Identifier: UNLICENSED␊ + pragma solidity ^0.8;␊ + ␊ + struct OneAndAHalfSlot {␊ + uint256 x;␊ + uint128 y;␊ + }␊ + ␊ + contract SizeDefault {␊ + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable state-variable-assignment␊ + uint immutable w1 = block.number;␊ + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable state-variable-assignment␊ + uint immutable w2 = block.timestamp;␊ + uint x; // slot 0 (after conversion to private)␊ + uint constant y = 1;␊ + uint224 z0; // slot 1␊ + uint256 z1; // slot 2␊ + uint32 z2; // slot 3␊ + OneAndAHalfSlot s1; // slot 4&5␊ + OneAndAHalfSlot s2; // slot 6&7␊ + uint32 z3; // slot 8␊ + uint32 z4; // slot 8␊ + uint32 z5; // slot 8␊ + uint64[5] a1; // slot 9&10␊ + uint64[3] a2; // slot 11␊ + ␊ + function __SizeDefault_init(uint _x) internal onlyInitializing {␊ + __SizeDefault_init_unchained(_x);␊ + }␊ + ␊ + function __SizeDefault_init_unchained(uint _x) internal onlyInitializing {␊ + x = _x;␊ + }␊ + // gap should be 38 = 50 - 12␊ + ␊ + /**␊ + * @dev This empty reserved space is put in place to allow future versions to add new␊ + * variables without shifting down storage in the inheritance chain.␊ + * See https://docs.openzeppelin.com/contracts/4.x/upgradeable#storage_gaps␊ + */␊ + uint256[38] private __gap;␊ + }␊ + ␊ + /// @custom:storage-size 128␊ + contract SizeOverride {␊ + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable state-variable-assignment␊ + uint immutable w1 = block.number;␊ + /// @custom:oz-upgrades-unsafe-allow state-variable-immutable state-variable-assignment␊ + uint immutable w2 = block.timestamp;␊ + uint x; // slot 0 (after conversion to private)␊ + uint constant y = 1;␊ + uint224 z0; // slot 1␊ + uint256 z1; // slot 2␊ + uint32 z2; // slot 3␊ + OneAndAHalfSlot s1; // slot 4&5␊ + OneAndAHalfSlot s2; // slot 6&7␊ + uint32 z3; // slot 8␊ + uint32 z4; // slot 8␊ + uint32 z5; // slot 8␊ + uint64[5] a1; // slot 9&10␊ + uint64[3] a2; // slot 11␊ + ␊ + function __SizeOverride_init(uint _x) internal onlyInitializing {␊ + __SizeOverride_init_unchained(_x);␊ + }␊ + ␊ + function __SizeOverride_init_unchained(uint _x) internal onlyInitializing {␊ + x = _x;␊ + }␊ + // gap should be 116 = 128 - 12␊ + ␊ + /**␊ + * @dev This empty reserved space is put in place to allow future versions to add new␊ + * variables without shifting down storage in the inheritance chain.␊ + * See https://docs.openzeppelin.com/contracts/4.x/upgradeable#storage_gaps␊ + */␊ + uint256[116] private __gap;␊ + }␊ + ` diff --git a/src/transform-0.8.test.ts.snap b/src/transform-0.8.test.ts.snap index 0843f66e..c03055dd 100644 Binary files a/src/transform-0.8.test.ts.snap and b/src/transform-0.8.test.ts.snap differ diff --git a/src/transformations/add-storage-gaps.ts b/src/transformations/add-storage-gaps.ts index b096622b..4d2760de 100644 --- a/src/transformations/add-storage-gaps.ts +++ b/src/transformations/add-storage-gaps.ts @@ -1,14 +1,17 @@ -import { SourceUnit, ContractDefinition } from 'solidity-ast'; -import { findAll } from 'solidity-ast/utils'; +import { SourceUnit, ContractDefinition, VariableDeclaration } from 'solidity-ast'; +import { findAll, isNodeType } from 'solidity-ast/utils'; import { formatLines } from './utils/format-lines'; +import { hasOverride } from '../utils/upgrades-overrides'; import { getNodeBounds } from '../solc/ast-utils'; import { StorageLayout } from '../solc/input-output'; import { Transformation } from './type'; import { TransformerTools } from '../transform'; +import { extractNatspec } from '../utils/extractNatspec'; +import { decodeTypeIdentifier } from '../utils/type-id'; -// 100 slots of 32 contractSize each -const TARGET_SIZE = 32 * 50; +// By default, make the contract a total of 50 slots (storage + gap) +const DEFAULT_SLOT_COUNT = 50; export function* addStorageGaps( sourceUnit: SourceUnit, @@ -16,7 +19,20 @@ export function* addStorageGaps( ): Generator { for (const contract of findAll('ContractDefinition', sourceUnit)) { if (contract.contractKind === 'contract') { - const gapSize = getGapSize(contract, getLayout(contract)); + let targetSlots = DEFAULT_SLOT_COUNT; + for (const entry of extractNatspec(contract)) { + if (entry.title === 'custom' && entry.tag === 'storage-size') { + targetSlots = parseInt(entry.args); + } + } + + const gapSize = targetSlots - getContractSlotCount(contract, getLayout(contract)); + + if (gapSize <= 0) { + throw new Error( + `Contract ${contract.name} uses more than the ${targetSlots} reserved slots.`, + ); + } const contractBounds = getNodeBounds(contract); const start = contractBounds.start + contractBounds.length - 1; @@ -43,24 +59,63 @@ export function* addStorageGaps( } } -function getGapSize(contractNode: ContractDefinition, layout: StorageLayout): number { - const varIds = new Set([...findAll('VariableDeclaration', contractNode)].map(v => v.id)); +function isStorageVariable(varDecl: VariableDeclaration): boolean { + switch (varDecl.mutability) { + case 'constant': + return false; + case 'immutable': + return !hasOverride(varDecl, 'state-variable-immutable'); + default: + return true; + } +} - if (layout === undefined) { - throw new Error('Storage layout is needed for this transformation'); +function getNumberOfBytesOfValueType(type: string) { + const details = type.match(/^t_(?[a-z]+)(?[\d]+)?$/); + switch (details?.groups?.base) { + case 'bool': + case 'byte': + return 1; + case 'address': + return 20; + case 'bytes': + return parseInt(details.groups.size, 10); + case 'int': + case 'uint': + return parseInt(details.groups.size, 10) / 8; + default: + throw new Error(`Unsupported value type: ${type}`); } +} + +function getContractSlotCount(contractNode: ContractDefinition, layout: StorageLayout): number { + // This tracks both slot and offset: + // - slot = Math.floor(contractSizeInBytes / 32) + // - offset = contractSizeInBytes % 32 + let contractSizeInBytes = 0; - const local = layout.storage.filter(l => varIds.has(l.astId)); + // don't use `findAll` here, we don't want to go recursive + for (const varDecl of contractNode.nodes.filter(isNodeType('VariableDeclaration'))) { + if (isStorageVariable(varDecl)) { + // try get type details + const typeIdentifier = decodeTypeIdentifier(varDecl.typeDescriptions.typeIdentifier ?? ''); + const type = layout.types?.[typeIdentifier]; - let contractSize = 0; + // size of current object from type details, or try to reconstruct it if + // they're not available try to reconstruct it, which can happen for + // immutable variables + const size = type + ? parseInt(type.numberOfBytes, 10) + : getNumberOfBytesOfValueType(typeIdentifier); - for (const l of local) { - const type = layout.types?.[l.type]; - if (type === undefined) { - throw new Error(`Missing type information for ${type}`); + // used space in the current slot + const offset = contractSizeInBytes % 32; + // remaining space in the current slot (only if slot is dirty) + const remaining = (32 - offset) % 32; + // if the remaining space is not enough to fit the current object, then consume the free space to start at next slot + contractSizeInBytes += (size > remaining ? remaining : 0) + size; } - contractSize += parseInt(type.numberOfBytes, 10); } - return Math.floor((TARGET_SIZE - contractSize) / 32); + return Math.ceil(contractSizeInBytes / 32); } diff --git a/src/utils/extractNatspec.ts b/src/utils/extractNatspec.ts new file mode 100644 index 00000000..0ee598b0 --- /dev/null +++ b/src/utils/extractNatspec.ts @@ -0,0 +1,28 @@ +import { StructuredDocumentation } from 'solidity-ast'; +import { execall } from './execall'; + +interface NatspecTag { + title: string; + tag: string; + args: string; +} + +export function* extractNatspec(node: { + documentation?: string | StructuredDocumentation | null; +}): Generator { + const doc = + typeof node.documentation === 'string' ? node.documentation : node.documentation?.text ?? ''; + + for (const { groups } of execall( + /^\s*(?:@(?\w+)(?::(?<tag>[a-z][a-z-]*))? )?(?<args>(?:(?!^\s@\w+)[^])*)/m, + doc, + )) { + if (groups) { + yield { + title: groups.title ?? '', + tag: groups.tag ?? '', + args: groups.args ?? '', + }; + } + } +} diff --git a/src/utils/type-id.ts b/src/utils/type-id.ts new file mode 100644 index 00000000..46bf8386 --- /dev/null +++ b/src/utils/type-id.ts @@ -0,0 +1,57 @@ +import assert from 'assert'; + +// Type Identifiers in the AST are for some reason encoded so that they don't +// contain parentheses or commas, which have been substituted as follows: +// ( -> $_ +// ) -> _$ +// , -> _$_ +// This is particularly hard to decode because it is not a prefix-free code. +// Thus, the following regex has to perform a lookahead to make sure it gets +// the substitution right. +export function decodeTypeIdentifier(typeIdentifier: string): string { + return typeIdentifier.replace(/(\$_|_\$_|_\$)(?=(\$_|_\$_|_\$)*([^_$]|$))/g, m => { + switch (m) { + case '$_': + return '('; + case '_$': + return ')'; + case '_$_': + return ','; + default: + throw new Error('Unreachable'); + } + }); +} + +// Some Type Identifiers contain a _storage_ptr suffix, but the _ptr part +// appears in some places and not others. We remove it to get consistent type +// ids from the different places in the AST. +export function normalizeTypeIdentifier(typeIdentifier: string): string { + return decodeTypeIdentifier(typeIdentifier).replace(/_storage_ptr\b/g, '_storage'); +} + +// Type Identifiers contain AST id numbers, which makes them sensitive to +// unrelated changes in the source code. This function stabilizes a type +// identifier by removing all AST ids. +export function stabilizeTypeIdentifier(typeIdentifier: string): string { + let decoded = decodeTypeIdentifier(typeIdentifier); + const re = /(t_struct|t_enum|t_contract)\(/g; + let match; + while ((match = re.exec(decoded))) { + let i; + let d = 1; + for (i = match.index + match[0].length; d !== 0; i++) { + assert(i < decoded.length, 'index out of bounds'); + const c = decoded[i]; + if (c === '(') { + d += 1; + } else if (c === ')') { + d -= 1; + } + } + const re2 = /\d+_?/y; + re2.lastIndex = i; + decoded = decoded.replace(re2, ''); + } + return decoded; +}