Skip to content

Commit

Permalink
Fix epoch block marshaling (#1688)
Browse files Browse the repository at this point in the history
* Extend transaction tracker to track blocks as well

Its now just called Tracker

* Rename transaction_tracker.go -> tracker.go

* Fix epoch block marshaling and add test for it

Epoch blocks previously were not json unmarshalable by the geth client,
they now are.
  • Loading branch information
piersy authored Sep 28, 2021
1 parent 00a9a9d commit ae053df
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 169 deletions.
28 changes: 28 additions & 0 deletions core/types/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package types

import (
"encoding/json"
"fmt"
"io"
"math/big"
Expand Down Expand Up @@ -183,6 +184,33 @@ func (r *EpochSnarkData) IsEmpty() bool {
return len(r.Signature) == 0
}

// MarshalJSON marshals as JSON.
func (r EpochSnarkData) MarshalJSON() ([]byte, error) {
type EpochSnarkData struct {
Bitmap hexutil.Bytes
Signature hexutil.Bytes
}
var enc EpochSnarkData
enc.Bitmap = r.Bitmap.Bytes()
enc.Signature = r.Signature
return json.Marshal(&enc)
}

// UnmarshalJSON unmarshals from JSON.
func (r *EpochSnarkData) UnmarshalJSON(input []byte) error {
type EpochSnarkData struct {
Bitmap hexutil.Bytes
Signature hexutil.Bytes
}
var dec EpochSnarkData
if err := json.Unmarshal(input, &dec); err != nil {
return err
}
r.Bitmap = new(big.Int).SetBytes(dec.Bitmap)
r.Signature = dec.Signature
return nil
}

// Body is a simple (mutable, non-safe) data container for storing and moving
// a block's data contents (transactions and uncles) together.
type Body struct {
Expand Down
29 changes: 29 additions & 0 deletions e2e_test/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/celo-org/celo-blockchain/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -37,3 +38,31 @@ func TestSendCelo(t *testing.T) {
err = network.AwaitTransactions(ctx, tx)
require.NoError(t, err)
}

// This test is intended to ensure that epoch blocks can be correctly marshalled.
// We previously had an open bug for this https://github.com/celo-org/celo-blockchain/issues/1574
func TestEpochBlockMarshaling(t *testing.T) {
accounts := test.Accounts(1)
gc, ec, err := test.BuildConfig(accounts)
require.NoError(t, err)

// Configure the shortest possible epoch, uptimeLookbackWindow minimum is 3
// and it needs to be < (epoch -2).
ec.Istanbul.Epoch = 6
ec.Istanbul.DefaultLookbackWindow = 3
network, err := test.NewNetwork(accounts, gc, ec)
require.NoError(t, err)
defer network.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

// Wait for the whole network to process the transaction.
err = network.AwaitBlock(ctx, 6)
require.NoError(t, err)
b := network[0].Tracker.GetProcessedBlock(6)

// Check that epoch snark data was actually unmarshalled, I.E there was
// something there.
assert.True(t, len(b.EpochSnarkData().Signature) > 0)
assert.True(t, b.EpochSnarkData().Bitmap.Uint64() > 0)
}
19 changes: 15 additions & 4 deletions test/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ type Node struct {
Address common.Address
DevKey *ecdsa.PrivateKey
DevAddress common.Address
Tracker *TransactionTracker
Tracker *Tracker
// The transactions that this node has sent.
SentTxs []*types.Transaction
}
Expand Down Expand Up @@ -137,7 +137,7 @@ func NewNode(
Address: validatorAccount.Address,
DevAddress: devAccount.Address,
DevKey: devAccount.PrivateKey,
Tracker: NewTransactionTracker(),
Tracker: NewTracker(),
}

return node, node.Start()
Expand Down Expand Up @@ -297,7 +297,7 @@ func (n *Node) AwaitSentTransactions(ctx context.Context) error {
// ProcessedTxBlock returns the block that the given transaction was processed
// in, nil will be retuned if the transaction has not been processed by this node.
func (n *Node) ProcessedTxBlock(tx *types.Transaction) *types.Block {
return n.Tracker.GetProcessedBlock(tx.Hash())
return n.Tracker.GetProcessedBlockForTx(tx.Hash())
}

// TxFee returns the gas fee for the given transaction.
Expand Down Expand Up @@ -364,7 +364,7 @@ func NewNetwork(accounts *env.AccountsConfig, gc *genesis.Config, ec *eth.Config

genesis, err := genesis.GenerateGenesis(accounts, gc, "../compiled-system-contracts")
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to generate genesis: %v", err)
}

va := accounts.ValidatorAccounts()
Expand Down Expand Up @@ -460,6 +460,17 @@ func (n Network) AwaitTransactions(ctx context.Context, txs ...*types.Transactio
return nil
}

// AwaitBlock ensures that the entire network has processed a block with the given num.
func (n Network) AwaitBlock(ctx context.Context, num uint64) error {
for _, node := range n {
err := node.Tracker.AwaitBlock(ctx, num)
if err != nil {
return err
}
}
return nil
}

// Shutdown closes all nodes in the network, any errors that are encountered are
// printed to stdout.
func (n Network) Shutdown() {
Expand Down
202 changes: 202 additions & 0 deletions test/tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
package test

import (
"context"
"errors"
"fmt"
"sync"

ethereum "github.com/celo-org/celo-blockchain"
"github.com/celo-org/celo-blockchain/common"
"github.com/celo-org/celo-blockchain/core/types"
"github.com/celo-org/celo-blockchain/ethclient"
"github.com/celo-org/celo-blockchain/event"
)

var (
errStopped = errors.New("transaction tracker closed")
)

// Tracker tracks processed blocks and transactions through a subscription with
// an ethclient. It provides the ability to check whether blocks or
// transactions have been processed and to wait till those blocks or
// transactions have been processed.
type Tracker struct {
client *ethclient.Client
heads chan *types.Header
sub ethereum.Subscription
wg sync.WaitGroup
// processedTxs maps transaction hashes to the block they were processed in.
processedTxs map[common.Hash]*types.Block
// processedBlocks maps block number to processed blocks.
processedBlocks map[uint64]*types.Block
processedMu sync.Mutex
stopCh chan struct{}
newBlock event.Feed
}

// NewTracker creates a new tracker.
func NewTracker() *Tracker {
return &Tracker{
heads: make(chan *types.Header, 10),
processedTxs: make(map[common.Hash]*types.Block),
processedBlocks: make(map[uint64]*types.Block),
}
}

// GetProcessedTx returns the processed transaction with the given hash or nil
// if the tracker has not seen a processed transaction with the given hash.
func (tr *Tracker) GetProcessedTx(hash common.Hash) *types.Transaction {
tr.processedMu.Lock()
defer tr.processedMu.Unlock()
return tr.processedTxs[hash].Transaction(hash)
}

// GetProcessedBlockForTx returns the block that a transaction with the given
// hash was processed in or nil if the tracker has not seen a processed
// transaction with the given hash.
func (tr *Tracker) GetProcessedBlockForTx(hash common.Hash) *types.Block {
tr.processedMu.Lock()
defer tr.processedMu.Unlock()
return tr.processedTxs[hash]
}

// GetProcessedBlock returns processed block with the given num or nil if the
// tracker has not seen a processed block with that num.
func (tr *Tracker) GetProcessedBlock(num uint64) *types.Block {
tr.processedMu.Lock()
defer tr.processedMu.Unlock()
return tr.processedBlocks[num]
}

// StartTracking subscribes to new head events on the client and starts
// processing the events in a goroutine.
func (tr *Tracker) StartTracking(client *ethclient.Client) error {
if tr.sub != nil {
return errors.New("attempted to start already started tracker")
}
// The subscription client will buffer 20000 notifications before closing
// the subscription, if that happens the Err() chan will return
// ErrSubscriptionQueueOverflow
sub, err := client.SubscribeNewHead(context.Background(), tr.heads)
if err != nil {
return err
}
tr.client = client
tr.sub = sub
tr.stopCh = make(chan struct{})

tr.wg.Add(1)
go func() {
defer tr.wg.Done()
err := tr.track()
if err != nil {
fmt.Printf("track failed with error: %v\n", err)
}
}()
return nil
}

// track reads new heads from the heads channel and for each head retrieves the
// block, places the block in processedBlocks and places the transactions into
// processedTxs. It signals the sub Subscription for each retrieved block.
func (tr *Tracker) track() error {
for {
select {
case h := <-tr.heads:
b, err := tr.client.BlockByHash(context.Background(), h.Hash())
if err != nil {
return err
}
tr.processedMu.Lock()
tr.processedBlocks[b.NumberU64()] = b
// If we have transactions then process them
if len(b.Transactions()) > 0 {
for _, t := range b.Transactions() {
tr.processedTxs[t.Hash()] = b
}
}
tr.processedMu.Unlock()
// signal
tr.newBlock.Send(struct{}{})
case err := <-tr.sub.Err():
// Will be nil if closed by calling Unsubscribe()
return err
case <-tr.stopCh:
return nil
}
}
}

// AwaitTransactions waits for the transactions listed in hashes to be
// processed, it will return the ctx.Err() if ctx expires before all the
// transactions in hashes were processed or ErrStopped if StopTracking is
// called before all the transactions in hashes were processed.
func (tr *Tracker) AwaitTransactions(ctx context.Context, hashes []common.Hash) error {
hashmap := make(map[common.Hash]struct{}, len(hashes))
for i := range hashes {
hashmap[hashes[i]] = struct{}{}
}
condition := func() bool {
for hash := range hashmap {
_, ok := tr.processedTxs[hash]
if ok {
delete(hashmap, hash)
}
}
// If there are no transactions left then they have all been processed.
return len(hashmap) == 0
}
return tr.await(ctx, condition)
}

// AwaitBlock waits for a block with the given num to be processed, it will
// return the ctx.Err() if ctx expires before a block with that number has been
// processed or ErrStopped if StopTracking is called before a block with that
// number is processed.
func (tr *Tracker) AwaitBlock(ctx context.Context, num uint64) error {
condition := func() bool {
return tr.processedBlocks[num] != nil
}
return tr.await(ctx, condition)
}

// await waits for the provided condition to return true, it rechecks the
// condition every time a new block is received by the Tracker. Await returns
// nil when the condition returns true, otherwise it will return ctx.Err() if
// ctx expires before the condition returns true or ErrStopped if StopTracking
// is called before the condition returns true.
func (tr *Tracker) await(ctx context.Context, condition func() bool) error {
ch := make(chan struct{}, 10)
sub := tr.newBlock.Subscribe(ch)
defer sub.Unsubscribe()
for {
tr.processedMu.Lock()
found := condition()
tr.processedMu.Unlock()
// If we found what we are looking for then return.
if found {
return nil
}
select {
case <-ch:
continue
case <-ctx.Done():
return ctx.Err()
case <-tr.stopCh:
return errStopped
}
}
}

// StopTracking shuts down all the goroutines in the tracker.
func (tr *Tracker) StopTracking() error {
if tr.sub == nil {
return errors.New("attempted to stop already stopped tracker")
}
tr.sub.Unsubscribe()
close(tr.stopCh)
tr.wg.Wait()
tr.wg = sync.WaitGroup{}
return nil
}
Loading

0 comments on commit ae053df

Please sign in to comment.