Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce memory allocations during state transition #5235

Merged
merged 2 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions beacon_chain/spec/beaconstate.nim
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import
"."/[eth2_merkleization, forks, signatures, validator]

from std/algorithm import fill
from std/sequtils import anyIt, mapIt
from std/sequtils import anyIt, mapIt, toSeq

from ./datatypes/capella import BeaconState, ExecutionPayloadHeader, Withdrawal

Expand Down Expand Up @@ -381,14 +381,15 @@ proc is_valid_indexed_attestation*(
ok()

# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/phase0/beacon-chain.md#get_attesting_indices
func get_attesting_indices*(state: ForkyBeaconState,
data: AttestationData,
bits: CommitteeValidatorsBits,
cache: var StateCache): seq[ValidatorIndex] =
iterator get_attesting_indices_iter*(state: ForkyBeaconState,
data: AttestationData,
bits: CommitteeValidatorsBits,
cache: var StateCache): ValidatorIndex =
## Return the set of attesting indices corresponding to ``data`` and ``bits``
## or nothing if `data` is invalid
## This iterator must not be called in functions using a
## ForkedHashedBeaconState due to https://github.com/nim-lang/Nim/issues/18188

var res: seq[ValidatorIndex]
# Can't be an iterator due to https://github.com/nim-lang/Nim/issues/18188
let committee_index = CommitteeIndex.init(data.index)
if committee_index.isErr() or bits.lenu64 != get_beacon_committee_len(
Expand All @@ -398,9 +399,17 @@ func get_attesting_indices*(state: ForkyBeaconState,
for index_in_committee, validator_index in get_beacon_committee(
state, data.slot, committee_index.get(), cache):
if bits[index_in_committee]:
res.add validator_index
yield validator_index

res
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/phase0/beacon-chain.md#get_attesting_indices
func get_attesting_indices*(state: ForkyBeaconState,
data: AttestationData,
bits: CommitteeValidatorsBits,
cache: var StateCache): seq[ValidatorIndex] =
## Return the set of attesting indices corresponding to ``data`` and ``bits``
## or nothing if `data` is invalid

toSeq(get_attesting_indices_iter(state, data, bits, cache))

func get_attesting_indices*(state: ForkedHashedBeaconState;
data: AttestationData;
Expand Down Expand Up @@ -432,7 +441,7 @@ proc is_valid_indexed_attestation*(
if not (skipBlsValidation in flags or attestation.signature is TrustedSig):
var
pubkeys = newSeqOfCap[ValidatorPubKey](sigs)
for index in get_attesting_indices(
for index in get_attesting_indices_iter(
state, attestation.data, attestation.aggregation_bits, cache):
pubkeys.add(state.validators[index].pubkey)

Expand Down Expand Up @@ -496,7 +505,7 @@ func check_attestation_index*(
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/altair/beacon-chain.md#get_attestation_participation_flag_indices
func get_attestation_participation_flag_indices(
state: altair.BeaconState | bellatrix.BeaconState | capella.BeaconState,
data: AttestationData, inclusion_delay: uint64): seq[int] =
data: AttestationData, inclusion_delay: uint64): set[TimelyFlag] =
## Return the flag indices that are satisfied by an attestation.
let justified_checkpoint =
if data.target.epoch == get_current_epoch(state):
Expand All @@ -517,20 +526,20 @@ func get_attestation_participation_flag_indices(
# Checked by check_attestation()
doAssert is_matching_source

var participation_flag_indices: seq[int]
var participation_flag_indices: set[TimelyFlag]
if is_matching_source and inclusion_delay <= integer_squareroot(SLOTS_PER_EPOCH):
participation_flag_indices.add(TIMELY_SOURCE_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_SOURCE_FLAG_INDEX)
if is_matching_target and inclusion_delay <= SLOTS_PER_EPOCH:
participation_flag_indices.add(TIMELY_TARGET_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_TARGET_FLAG_INDEX)
if is_matching_head and inclusion_delay == MIN_ATTESTATION_INCLUSION_DELAY:
participation_flag_indices.add(TIMELY_HEAD_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_HEAD_FLAG_INDEX)

participation_flag_indices

# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/deneb/beacon-chain.md#modified-get_attestation_participation_flag_indices
func get_attestation_participation_flag_indices(
state: deneb.BeaconState,
data: AttestationData, inclusion_delay: uint64): seq[int] =
data: AttestationData, inclusion_delay: uint64): set[TimelyFlag] =
## Return the flag indices that are satisfied by an attestation.
let justified_checkpoint =
if data.target.epoch == get_current_epoch(state):
Expand All @@ -551,13 +560,13 @@ func get_attestation_participation_flag_indices(
# Checked by check_attestation
doAssert is_matching_source

var participation_flag_indices: seq[int]
var participation_flag_indices: set[TimelyFlag]
if is_matching_source and inclusion_delay <= integer_squareroot(SLOTS_PER_EPOCH):
participation_flag_indices.add(TIMELY_SOURCE_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_SOURCE_FLAG_INDEX)
if is_matching_target: # [Modified in Deneb:EIP7045]
participation_flag_indices.add(TIMELY_TARGET_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_TARGET_FLAG_INDEX)
if is_matching_head and inclusion_delay == MIN_ATTESTATION_INCLUSION_DELAY:
participation_flag_indices.add(TIMELY_HEAD_FLAG_INDEX)
participation_flag_indices.incl(TIMELY_HEAD_FLAG_INDEX)

participation_flag_indices

Expand Down Expand Up @@ -672,7 +681,7 @@ func get_proposer_reward*(state: ForkyBeaconState,
epoch_participation: var EpochParticipationFlags): uint64 =
let participation_flag_indices = get_attestation_participation_flag_indices(
state, attestation.data, state.slot - attestation.data.slot)
for index in get_attesting_indices(
for index in get_attesting_indices_iter(
state, attestation.data, attestation.aggregation_bits, cache):
let
base_reward = get_base_reward(state, index, base_reward_per_increment)
Expand Down Expand Up @@ -1115,7 +1124,7 @@ func translate_participation(
get_attestation_participation_flag_indices(state, data, inclusion_delay)

# Apply flags to all attesting validators
for index in get_attesting_indices(
for index in get_attesting_indices_iter(
state, data, attestation.aggregation_bits, cache):
for flag_index in participation_flag_indices:
state.previous_epoch_participation[index] =
Expand Down
24 changes: 16 additions & 8 deletions beacon_chain/spec/datatypes/altair.nim
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ export base, sets
from ssz_serialization/proofs import GeneralizedIndex
export proofs.GeneralizedIndex

type
TimelyFlag* {.pure.} = enum
TIMELY_SOURCE_FLAG_INDEX
TIMELY_TARGET_FLAG_INDEX
TIMELY_HEAD_FLAG_INDEX

static:
# Verify that ordinals follow spec values (the spec uses these as shifts for bit flags)
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#participation-flag-indices
doAssert ord(TIMELY_SOURCE_FLAG_INDEX) == 0
doAssert ord(TIMELY_TARGET_FLAG_INDEX) == 1
doAssert ord(TIMELY_HEAD_FLAG_INDEX) == 2

const
# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#incentivization-weights
TIMELY_SOURCE_WEIGHT* = 14
Expand All @@ -35,8 +48,8 @@ const
PROPOSER_WEIGHT* = 8
WEIGHT_DENOMINATOR* = 64

PARTICIPATION_FLAG_WEIGHTS* =
[TIMELY_SOURCE_WEIGHT, TIMELY_TARGET_WEIGHT, TIMELY_HEAD_WEIGHT]
PARTICIPATION_FLAG_WEIGHTS*: array[TimelyFlag, uint64] =
[uint64 TIMELY_SOURCE_WEIGHT, TIMELY_TARGET_WEIGHT, TIMELY_HEAD_WEIGHT]

# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/altair/validator.md#misc
TARGET_AGGREGATORS_PER_SYNC_SUBCOMMITTEE* = 16
Expand All @@ -52,11 +65,6 @@ const
CURRENT_SYNC_COMMITTEE_INDEX* = 54.GeneralizedIndex # `current_sync_committee`
NEXT_SYNC_COMMITTEE_INDEX* = 55.GeneralizedIndex # `next_sync_committee`

# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#participation-flag-indices
TIMELY_SOURCE_FLAG_INDEX* = 0
TIMELY_TARGET_FLAG_INDEX* = 1
TIMELY_HEAD_FLAG_INDEX* = 2

# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#inactivity-penalties
INACTIVITY_SCORE_BIAS* = 4
INACTIVITY_SCORE_RECOVERY_RATE* = 16
Expand Down Expand Up @@ -310,7 +318,7 @@ type
next_sync_committee*: SyncCommittee # [New in Altair]

UnslashedParticipatingBalances* = object
previous_epoch*: array[PARTICIPATION_FLAG_WEIGHTS.len, Gwei]
previous_epoch*: array[TimelyFlag, Gwei]
current_epoch_TIMELY_TARGET*: Gwei
current_epoch*: Gwei # aka total_active_balance

Expand Down
8 changes: 4 additions & 4 deletions beacon_chain/spec/helpers.nim
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,13 @@ func get_seed*(state: ForkyBeaconState, epoch: Epoch, domain_type: DomainType):
state.get_seed(epoch, domain_type, mix)

# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#add_flag
func add_flag*(flags: ParticipationFlags, flag_index: int): ParticipationFlags =
let flag = ParticipationFlags(1'u8 shl flag_index)
func add_flag*(flags: ParticipationFlags, flag_index: TimelyFlag): ParticipationFlags =
let flag = ParticipationFlags(1'u8 shl ord(flag_index))
flags or flag

# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#has_flag
func has_flag*(flags: ParticipationFlags, flag_index: int): bool =
let flag = ParticipationFlags(1'u8 shl flag_index)
func has_flag*(flags: ParticipationFlags, flag_index: TimelyFlag): bool =
let flag = ParticipationFlags(1'u8 shl ord(flag_index))
(flags and flag) == flag

# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/altair/light-client/sync-protocol.md#is_sync_committee_update
Expand Down
17 changes: 8 additions & 9 deletions beacon_chain/spec/state_transition_epoch.nim
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func process_attestation(
flags.incl RewardFlags.isPreviousEpochHeadAttester

# Update the cache for all participants
for validator_index in get_attesting_indices(
for validator_index in get_attesting_indices_iter(
state, a.data, a.aggregation_bits, cache):
template v(): untyped = info.validators[validator_index]

Expand Down Expand Up @@ -205,7 +205,7 @@ func get_unslashed_participating_balances*(
state.previous_epoch_participation[validator_index]

if is_active_previous_epoch:
for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
if has_flag(previous_epoch_participation, flag_index):
res.previous_epoch[flag_index] += validator_effective_balance

Expand All @@ -216,7 +216,7 @@ func get_unslashed_participating_balances*(
TIMELY_TARGET_FLAG_INDEX):
res.current_epoch_TIMELY_TARGET += validator_effective_balance

for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
res.previous_epoch[flag_index] =
max(EFFECTIVE_BALANCE_INCREMENT, res.previous_epoch[flag_index])

Expand All @@ -230,7 +230,7 @@ func get_unslashed_participating_balances*(
func is_unslashed_participating_index(
state: altair.BeaconState | bellatrix.BeaconState | capella.BeaconState |
deneb.BeaconState,
flag_index: int, epoch: Epoch, validator_index: ValidatorIndex): bool =
flag_index: TimelyFlag, epoch: Epoch, validator_index: ValidatorIndex): bool =
doAssert epoch in [get_previous_epoch(state), get_current_epoch(state)]
# TODO hoist this conditional
let epoch_participation =
Expand Down Expand Up @@ -658,7 +658,7 @@ func get_flag_index_reward*(

# https://github.com/ethereum/consensus-specs/blob/v1.4.0-beta.0/specs/altair/beacon-chain.md#get_flag_index_deltas
func get_unslashed_participating_increment*(
info: altair.EpochInfo | bellatrix.BeaconState, flag_index: int): Gwei =
info: altair.EpochInfo | bellatrix.BeaconState, flag_index: TimelyFlag): Gwei =
info.balances.previous_epoch[flag_index] div EFFECTIVE_BALANCE_INCREMENT

# https://github.com/ethereum/consensus-specs/blob/v1.4.0-alpha.3/specs/altair/beacon-chain.md#get_flag_index_deltas
Expand All @@ -670,14 +670,14 @@ func get_active_increments*(
iterator get_flag_index_deltas*(
state: altair.BeaconState | bellatrix.BeaconState | capella.BeaconState |
deneb.BeaconState,
flag_index: int, base_reward_per_increment: Gwei,
flag_index: TimelyFlag, base_reward_per_increment: Gwei,
info: var altair.EpochInfo, finality_delay: uint64):
(ValidatorIndex, RewardDelta) =
## Return the deltas for a given ``flag_index`` by scanning through the
## participation flags.
let
previous_epoch = get_previous_epoch(state)
weight = PARTICIPATION_FLAG_WEIGHTS[flag_index].uint64 # safe
weight = PARTICIPATION_FLAG_WEIGHTS[flag_index]
unslashed_participating_increments = get_unslashed_participating_increment(
info, flag_index)
active_increments = get_active_increments(info)
Expand All @@ -695,7 +695,6 @@ iterator get_flag_index_deltas*(
of TIMELY_SOURCE_FLAG_INDEX: ParticipationFlag.timelySourceAttester
of TIMELY_TARGET_FLAG_INDEX: ParticipationFlag.timelyTargetAttester
of TIMELY_HEAD_FLAG_INDEX: ParticipationFlag.timelyHeadAttester
else: raiseAssert "Unknown flag index " & $flag_index

info.validators[vidx].flags.incl pflag

Expand Down Expand Up @@ -796,7 +795,7 @@ func process_rewards_and_penalties*(
finality_delay = get_finality_delay(state)

doAssert state.validators.len() == info.validators.len()
for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state, flag_index, base_reward_per_increment, info, finality_delay):
info.validators[validator_index].delta.add(delta)
Expand Down
6 changes: 2 additions & 4 deletions ncli/ncli_common.nim
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ proc collectEpochRewardsAndPenalties*(
total_active_balance)
finality_delay = get_finality_delay(state)

for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state, flag_index, base_reward_per_increment, info, finality_delay):
template rp: untyped = rewardsAndPenalties[validator_index]
Expand All @@ -302,7 +302,7 @@ proc collectEpochRewardsAndPenalties*(
max_flag_index_reward = get_flag_index_reward(
state, base_reward, active_increments,
unslashed_participating_increment,
PARTICIPATION_FLAG_WEIGHTS[flag_index].uint64,
PARTICIPATION_FLAG_WEIGHTS[flag_index],
finality_delay)

case flag_index
Expand All @@ -315,8 +315,6 @@ proc collectEpochRewardsAndPenalties*(
of TIMELY_HEAD_FLAG_INDEX:
rp.head_outcome = delta.getOutcome
rp.max_head_reward = max_flag_index_reward
else:
raiseAssert(&"Unknown flag index {flag_index}.")

for validator_index, penalty in get_inactivity_penalty_deltas(
cfg, state, info):
Expand Down
7 changes: 3 additions & 4 deletions tests/consensus_spec/altair/test_fixture_rewards.nim
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# beacon_chain
# Copyright (c) 2020-2022 Status Research & Development GmbH
# Copyright (c) 2020-2023 Status Research & Development GmbH
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
Expand Down Expand Up @@ -49,17 +49,16 @@ proc runTest(rewardsDir, identifier: string) =
total_balance = info.balances.current_epoch
base_reward_per_increment = get_base_reward_per_increment(total_balance)

static: doAssert PARTICIPATION_FLAG_WEIGHTS.len == 3
var
flagDeltas2 = [
flagDeltas2: array[TimelyFlag, Deltas] = [
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len)]
inactivityPenaltyDeltas2 = Deltas.init(state[].validators.len)

let finality_delay = get_finality_delay(state[])

for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state[], flag_index, base_reward_per_increment, info, finality_delay):
if not is_eligible_validator(info.validators[validator_index]):
Expand Down
7 changes: 3 additions & 4 deletions tests/consensus_spec/bellatrix/test_fixture_rewards.nim
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# beacon_chain
# Copyright (c) 2020-2022 Status Research & Development GmbH
# Copyright (c) 2020-2023 Status Research & Development GmbH
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
Expand Down Expand Up @@ -49,17 +49,16 @@ proc runTest(rewardsDir, identifier: string) =
total_balance = info.balances.current_epoch
base_reward_per_increment = get_base_reward_per_increment(total_balance)

static: doAssert PARTICIPATION_FLAG_WEIGHTS.len == 3
var
flagDeltas2 = [
flagDeltas2: array[TimelyFlag, Deltas] = [
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len)]
inactivityPenaltyDeltas2 = Deltas.init(state[].validators.len)

let finality_delay = get_finality_delay(state[])

for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state[], flag_index, base_reward_per_increment, info, finality_delay):
if not is_eligible_validator(info.validators[validator_index]):
Expand Down
7 changes: 3 additions & 4 deletions tests/consensus_spec/capella/test_fixture_rewards.nim
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# beacon_chain
# Copyright (c) 2020-2022 Status Research & Development GmbH
# Copyright (c) 2020-2023 Status Research & Development GmbH
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
Expand Down Expand Up @@ -49,17 +49,16 @@ proc runTest(rewardsDir, identifier: string) =
total_balance = info.balances.current_epoch
base_reward_per_increment = get_base_reward_per_increment(total_balance)

static: doAssert PARTICIPATION_FLAG_WEIGHTS.len == 3
var
flagDeltas2 = [
flagDeltas2: array[TimelyFlag, Deltas] = [
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len),
Deltas.init(state[].validators.len)]
inactivityPenaltyDeltas2 = Deltas.init(state[].validators.len)

let finality_delay = get_finality_delay(state[])

for flag_index in 0 ..< PARTICIPATION_FLAG_WEIGHTS.len:
for flag_index in TimelyFlag:
for validator_index, delta in get_flag_index_deltas(
state[], flag_index, base_reward_per_increment, info, finality_delay):
if not is_eligible_validator(info.validators[validator_index]):
Expand Down
Loading