Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

polygon/heimdall: use generics to eliminate casts #10371

Merged
merged 3 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions polygon/heimdall/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"github.com/ledgerwatch/erigon-lib/kv"
)

var _ Waypoint = Checkpoint{}

type CheckpointId uint64

// Checkpoint defines a response object type of bor checkpoint
Expand All @@ -20,50 +18,53 @@ type Checkpoint struct {
Fields WaypointFields
}

func (c Checkpoint) RawId() uint64 {
var _ Entity = &Checkpoint{}
var _ Waypoint = &Checkpoint{}

func (c *Checkpoint) RawId() uint64 {
return uint64(c.Id)
}

func (c Checkpoint) StartBlock() *big.Int {
func (c *Checkpoint) StartBlock() *big.Int {
return c.Fields.StartBlock
}

func (c Checkpoint) EndBlock() *big.Int {
func (c *Checkpoint) EndBlock() *big.Int {
return c.Fields.EndBlock
}

func (c Checkpoint) BlockNumRange() ClosedRange {
func (c *Checkpoint) BlockNumRange() ClosedRange {
return ClosedRange{
Start: c.StartBlock().Uint64(),
End: c.EndBlock().Uint64(),
}
}

func (c Checkpoint) RootHash() libcommon.Hash {
func (c *Checkpoint) RootHash() libcommon.Hash {
return c.Fields.RootHash
}

func (c Checkpoint) Timestamp() uint64 {
func (c *Checkpoint) Timestamp() uint64 {
return c.Fields.Timestamp
}

func (c Checkpoint) Length() uint64 {
func (c *Checkpoint) Length() uint64 {
return c.Fields.Length()
}

func (c Checkpoint) CmpRange(n uint64) int {
func (c *Checkpoint) CmpRange(n uint64) int {
return c.Fields.CmpRange(n)
}

func (m Checkpoint) String() string {
func (c *Checkpoint) String() string {
return fmt.Sprintf(
"Checkpoint {%v (%d:%d) %v %v %v}",
m.Fields.Proposer.String(),
m.Fields.StartBlock,
m.Fields.EndBlock,
m.Fields.RootHash.Hex(),
m.Fields.ChainID,
m.Fields.Timestamp,
c.Fields.Proposer.String(),
c.Fields.StartBlock,
c.Fields.EndBlock,
c.Fields.RootHash.Hex(),
c.Fields.ChainID,
c.Fields.Timestamp,
)
}

Expand Down
34 changes: 17 additions & 17 deletions polygon/heimdall/entity_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,29 @@ import (
"github.com/ledgerwatch/log/v3"
)

type entityFetcher interface {
type entityFetcher[TEntity Entity] interface {
FetchLastEntityId(ctx context.Context) (uint64, error)
FetchEntitiesRange(ctx context.Context, idRange ClosedRange) ([]Entity, error)
FetchEntitiesRange(ctx context.Context, idRange ClosedRange) ([]TEntity, error)
}

type entityFetcherImpl struct {
type entityFetcherImpl[TEntity Entity] struct {
name string

fetchLastEntityId func(ctx context.Context) (int64, error)
fetchEntity func(ctx context.Context, id int64) (Entity, error)
fetchEntitiesPage func(ctx context.Context, page uint64, limit uint64) ([]Entity, error)
fetchEntity func(ctx context.Context, id int64) (TEntity, error)
fetchEntitiesPage func(ctx context.Context, page uint64, limit uint64) ([]TEntity, error)

logger log.Logger
}

func newEntityFetcher(
func newEntityFetcher[TEntity Entity](
name string,
fetchLastEntityId func(ctx context.Context) (int64, error),
fetchEntity func(ctx context.Context, id int64) (Entity, error),
fetchEntitiesPage func(ctx context.Context, page uint64, limit uint64) ([]Entity, error),
fetchEntity func(ctx context.Context, id int64) (TEntity, error),
fetchEntitiesPage func(ctx context.Context, page uint64, limit uint64) ([]TEntity, error),
logger log.Logger,
) entityFetcher {
return &entityFetcherImpl{
) entityFetcher[TEntity] {
return &entityFetcherImpl[TEntity]{
name: name,
fetchLastEntityId: fetchLastEntityId,
fetchEntity: fetchEntity,
Expand All @@ -41,12 +41,12 @@ func newEntityFetcher(
}
}

func (f *entityFetcherImpl) FetchLastEntityId(ctx context.Context) (uint64, error) {
func (f *entityFetcherImpl[TEntity]) FetchLastEntityId(ctx context.Context) (uint64, error) {
id, err := f.fetchLastEntityId(ctx)
return uint64(id), err
}

func (f *entityFetcherImpl) FetchEntitiesRange(ctx context.Context, idRange ClosedRange) ([]Entity, error) {
func (f *entityFetcherImpl[TEntity]) FetchEntitiesRange(ctx context.Context, idRange ClosedRange) ([]TEntity, error) {
count := idRange.Len()

const batchFetchThreshold = 100
Expand All @@ -62,20 +62,20 @@ func (f *entityFetcherImpl) FetchEntitiesRange(ctx context.Context, idRange Clos
return f.FetchEntitiesRangeSequentially(ctx, idRange)
}

func (f *entityFetcherImpl) FetchEntitiesRangeSequentially(ctx context.Context, idRange ClosedRange) ([]Entity, error) {
return ClosedRangeMap(idRange, func(id uint64) (Entity, error) {
func (f *entityFetcherImpl[TEntity]) FetchEntitiesRangeSequentially(ctx context.Context, idRange ClosedRange) ([]TEntity, error) {
return ClosedRangeMap(idRange, func(id uint64) (TEntity, error) {
return f.fetchEntity(ctx, int64(id))
})
}

func (f *entityFetcherImpl) FetchAllEntities(ctx context.Context) ([]Entity, error) {
func (f *entityFetcherImpl[TEntity]) FetchAllEntities(ctx context.Context) ([]TEntity, error) {
// TODO: once heimdall API is fixed to return sorted items in pages we can only fetch
//
// the new pages after lastStoredCheckpointId using the checkpoints/list paging API
// (for now we have to fetch all of them)
// and also remove sorting we do after fetching

var entities []Entity
var entities []TEntity

fetchStartTime := time.Now()
progressLogTicker := time.NewTicker(30 * time.Second)
Expand Down Expand Up @@ -106,7 +106,7 @@ func (f *entityFetcherImpl) FetchAllEntities(ctx context.Context) ([]Entity, err
}
}

slices.SortFunc(entities, func(e1, e2 Entity) int {
slices.SortFunc(entities, func(e1, e2 TEntity) int {
n1 := e1.BlockNumRange().Start
n2 := e2.BlockNumRange().Start
return cmp.Compare(n1, n2)
Expand Down
74 changes: 41 additions & 33 deletions polygon/heimdall/entity_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@ import (
"github.com/ledgerwatch/erigon-lib/kv/iter"
)

type entityStore interface {
type entityStore[TEntity Entity] interface {
Prepare(ctx context.Context) error
Close()
GetLastEntityId(ctx context.Context) (uint64, bool, error)
GetLastEntity(ctx context.Context) (Entity, error)
GetEntity(ctx context.Context, id uint64) (Entity, error)
PutEntity(ctx context.Context, id uint64, entity Entity) error
FindByBlockNum(ctx context.Context, blockNum uint64) (Entity, error)
RangeFromId(ctx context.Context, startId uint64) ([]Entity, error)
RangeFromBlockNum(ctx context.Context, startBlockNum uint64) ([]Entity, error)
GetLastEntity(ctx context.Context) (TEntity, error)
GetEntity(ctx context.Context, id uint64) (TEntity, error)
PutEntity(ctx context.Context, id uint64, entity TEntity) error
FindByBlockNum(ctx context.Context, blockNum uint64) (TEntity, error)
RangeFromId(ctx context.Context, startId uint64) ([]TEntity, error)
RangeFromBlockNum(ctx context.Context, startBlockNum uint64) ([]TEntity, error)
}

type entityStoreImpl struct {
type entityStoreImpl[TEntity Entity] struct {
tx kv.RwTx
table string

makeEntity func() Entity
makeEntity func() TEntity
getLastEntityId func(ctx context.Context, tx kv.Tx) (uint64, bool, error)
loadEntityBytes func(ctx context.Context, tx kv.Getter, id uint64) ([]byte, error)

Expand All @@ -35,15 +35,15 @@ type entityStoreImpl struct {
prepareOnce sync.Once
}

func newEntityStore(
func newEntityStore[TEntity Entity](
tx kv.RwTx,
table string,
makeEntity func() Entity,
makeEntity func() TEntity,
getLastEntityId func(ctx context.Context, tx kv.Tx) (uint64, bool, error),
loadEntityBytes func(ctx context.Context, tx kv.Getter, id uint64) ([]byte, error),
blockNumToIdIndexFactory func(ctx context.Context) (*RangeIndex, error),
) entityStore {
return &entityStoreImpl{
) entityStore[TEntity] {
return &entityStoreImpl[TEntity]{
tx: tx,
table: table,

Expand All @@ -55,7 +55,7 @@ func newEntityStore(
}
}

func (s *entityStoreImpl) Prepare(ctx context.Context) error {
func (s *entityStoreImpl[TEntity]) Prepare(ctx context.Context) error {
var err error
s.prepareOnce.Do(func() {
s.blockNumToIdIndex, err = s.blockNumToIdIndexFactory(ctx)
Expand All @@ -68,22 +68,30 @@ func (s *entityStoreImpl) Prepare(ctx context.Context) error {
return err
}

func (s *entityStoreImpl) Close() {
func (s *entityStoreImpl[TEntity]) Close() {
s.blockNumToIdIndex.Close()
}

func (s *entityStoreImpl) GetLastEntityId(ctx context.Context) (uint64, bool, error) {
func (s *entityStoreImpl[TEntity]) GetLastEntityId(ctx context.Context) (uint64, bool, error) {
return s.getLastEntityId(ctx, s.tx)
}

func (s *entityStoreImpl) GetLastEntity(ctx context.Context) (Entity, error) {
// Zero value of any type T
// https://stackoverflow.com/questions/70585852/return-default-value-for-generic-type)
// https://go.dev/ref/spec#The_zero_value
func Zero[T any]() T {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice! helpful function - it feels like we can add it to erigon-lib/common - we can gradually build common generics helpers if other use cases come along

var value T
return value
}

func (s *entityStoreImpl[TEntity]) GetLastEntity(ctx context.Context) (TEntity, error) {
id, ok, err := s.GetLastEntityId(ctx)
if err != nil {
return nil, err
return Zero[TEntity](), err
}
// not found
if !ok {
return nil, nil
return Zero[TEntity](), nil
}
return s.GetEntity(ctx, id)
}
Expand All @@ -94,28 +102,28 @@ func entityStoreKey(id uint64) [8]byte {
return key
}

func (s *entityStoreImpl) entityUnmarshalJSON(jsonBytes []byte) (Entity, error) {
func (s *entityStoreImpl[TEntity]) entityUnmarshalJSON(jsonBytes []byte) (TEntity, error) {
entity := s.makeEntity()
if err := json.Unmarshal(jsonBytes, entity); err != nil {
return nil, err
return Zero[TEntity](), err
}
return entity, nil
}

func (s *entityStoreImpl) GetEntity(ctx context.Context, id uint64) (Entity, error) {
func (s *entityStoreImpl[TEntity]) GetEntity(ctx context.Context, id uint64) (TEntity, error) {
jsonBytes, err := s.loadEntityBytes(ctx, s.tx, id)
if err != nil {
return nil, err
return Zero[TEntity](), err
}
// not found
if jsonBytes == nil {
return nil, nil
return Zero[TEntity](), nil
}

return s.entityUnmarshalJSON(jsonBytes)
}

func (s *entityStoreImpl) PutEntity(ctx context.Context, id uint64, entity Entity) error {
func (s *entityStoreImpl[TEntity]) PutEntity(ctx context.Context, id uint64, entity TEntity) error {
jsonBytes, err := json.Marshal(entity)
if err != nil {
return err
Expand All @@ -131,27 +139,27 @@ func (s *entityStoreImpl) PutEntity(ctx context.Context, id uint64, entity Entit
return s.blockNumToIdIndex.Put(ctx, entity.BlockNumRange(), id)
}

func (s *entityStoreImpl) FindByBlockNum(ctx context.Context, blockNum uint64) (Entity, error) {
func (s *entityStoreImpl[TEntity]) FindByBlockNum(ctx context.Context, blockNum uint64) (TEntity, error) {
id, err := s.blockNumToIdIndex.Lookup(ctx, blockNum)
if err != nil {
return nil, err
return Zero[TEntity](), err
}
// not found
if id == 0 {
return nil, nil
return Zero[TEntity](), nil
}

return s.GetEntity(ctx, id)
}

func (s *entityStoreImpl) RangeFromId(_ context.Context, startId uint64) ([]Entity, error) {
func (s *entityStoreImpl[TEntity]) RangeFromId(_ context.Context, startId uint64) ([]TEntity, error) {
startKey := entityStoreKey(startId)
it, err := s.tx.Range(s.table, startKey[:], nil)
if err != nil {
return nil, err
}

var entities []Entity
var entities []TEntity
for it.HasNext() {
_, jsonBytes, err := it.Next()
if err != nil {
Expand All @@ -167,7 +175,7 @@ func (s *entityStoreImpl) RangeFromId(_ context.Context, startId uint64) ([]Enti
return entities, nil
}

func (s *entityStoreImpl) RangeFromBlockNum(ctx context.Context, startBlockNum uint64) ([]Entity, error) {
func (s *entityStoreImpl[TEntity]) RangeFromBlockNum(ctx context.Context, startBlockNum uint64) ([]TEntity, error) {
id, err := s.blockNumToIdIndex.Lookup(ctx, startBlockNum)
if err != nil {
return nil, err
Expand All @@ -180,11 +188,11 @@ func (s *entityStoreImpl) RangeFromBlockNum(ctx context.Context, startBlockNum u
return s.RangeFromId(ctx, id)
}

func buildBlockNumToIdIndex(
func buildBlockNumToIdIndex[TEntity Entity](
ctx context.Context,
index *RangeIndex,
iteratorFactory func() (iter.KV, error),
entityUnmarshalJSON func([]byte) (Entity, error),
entityUnmarshalJSON func([]byte) (TEntity, error),
) error {
it, err := iteratorFactory()
if err != nil {
Expand Down
Loading
Loading