From 6a0ea846bb26c4b41cb65cbbc1451555a769f481 Mon Sep 17 00:00:00 2001 From: Matthias Fasching <5011972+fasmat@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:53:03 +0000 Subject: [PATCH] Move node key to config directory and enable loading of multiple identities (#5592) ## Motivation This completes the implementation of multi-smeshing support. Related: https://github.com/spacemeshos/post/pull/270 --- CHANGELOG.md | 77 ++- activation/activation.go | 72 ++- activation/activation_multi_test.go | 50 +- activation/activation_test.go | 152 +++++- activation/e2e/activation_test.go | 20 +- activation/e2e/nipost_test.go | 25 +- activation/e2e/validation_test.go | 6 +- activation/handler.go | 5 +- activation/interface.go | 4 +- activation/mocks.go | 24 +- activation/post.go | 25 +- activation/post_supervisor.go | 7 +- activation/post_supervisor_test.go | 55 ++- activation/post_test.go | 163 +++---- activation/post_verifier.go | 4 +- activation/validation_test.go | 12 +- api/grpcserver/grpcserver_test.go | 56 ++- api/grpcserver/interface.go | 2 +- api/grpcserver/mocks.go | 12 +- api/grpcserver/post_service_test.go | 14 +- api/grpcserver/smesher_service.go | 8 +- api/grpcserver/smesher_service_test.go | 48 +- beacon/beacon.go | 10 +- blocks/certifier.go | 10 +- checkpoint/recovery.go | 174 ++++--- checkpoint/recovery_test.go | 431 +++++++++------- checkpoint/runner_test.go | 38 +- checkpoint/util.go | 12 +- checkpoint/util_test.go | 22 +- go.mod | 8 +- go.sum | 16 +- hare3/hare.go | 6 +- malfeasance/handler.go | 9 +- malfeasance/handler_test.go | 24 +- miner/proposal_builder.go | 11 +- node/bad_peer_test.go | 2 - node/node.go | 157 +++--- node/node_identities.go | 209 ++++++++ node/node_identities_test.go | 178 +++++++ node/node_test.go | 459 ++++++++++-------- node/test_network.go | 6 +- proposals/util/util.go | 10 +- systest/tests/checkpoint_test.go | 1 + .../distributed_post_verification_test.go | 3 +- systest/tests/smeshing_test.go | 1 + tortoise/algorithm.go | 4 +- 46 
files changed, 1760 insertions(+), 882 deletions(-) create mode 100644 node/node_identities.go create mode 100644 node/node_identities_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index e53afaa7a6..6253e7dbeb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,12 @@ encrypted connection between the post service and the node over insecure connect Smeshers using the default setup with a supervised post service do not need to make changes to their node configuration. +#### Fully migrated local state into `node_state.sql` + +With this release the node has fully migrated its local state into `node_state.sql`. During the first start after the +upgrade the node will migrate the data from disk and store it in the database. This change also allows the PoST data +directory to be set to read only after the migration is complete, as the node will no longer write to it. + #### New poets configuration Upgrading requires changes in config and in CLI flags (if not using the default). @@ -92,17 +98,65 @@ configuration is as follows: } ``` +#### Extend go-spacemesh with option to manage multiple identities/PoST services + +**NOTE:** This is a new feature, not yet supported by Smapp and possibly subject to change. Please use with caution. + +A node can now manage multiple identities and their life cycle. This reduces the amount of data that needs to be +broadcast / fetched from the network and reduces the amount of data that needs to be stored locally, because only one +database is needed for all identities instead of one for each. + +To ensure you are eligible for rewards of any given identity, the associated PoST service must be running and connected +to the node during the cyclegap set in the node's configuration. After successfully broadcasting the ATX and registering +at a PoET server the PoST services can be stopped with only the node having to be online. 
+ +This change moves the private keys associated with an identity from the PoST data directory to the node's data directory +and into the folder `identities` (i.e. if `state.sql` is in folder `data` the keys will now be stored in `data/identities`). +The node will automatically migrate the `key.bin` file from the PoST data directory during the first startup and copy +it to the new location as `identity.key`. The content of the file stays unchanged (= the private key of the identity hex-encoded). + +##### Adding new identities/PoST services to a node + +To add a new identity to a node, initialize PoST data with `postcli` and let it generate a new private key for you: + +```shell +./postcli -provider=2 -numUnits=4 -datadir=/path/to/data \ + -commitmentAtxId=c230c51669d1fcd35860131e438e234726b2bd5f9adbbd91bd88a718e7e98ecb +``` + +Make sure to replace `provider` with your provider of choice and `numUnits` with the number of PoST units you want to +initialize. The `commitmentAtxId` is the commitment ATX ID for the identity you want to initialize. For details on the +usage of `postcli` please refer to [postcli README](https://github.com/spacemeshos/post/cmd/postcli/README.md). + +During initialization `postcli` will generate a new private key and store it in the PoST data directory as `key.bin`. +Copy this file to your `data/identities` directory and rename it to `xxx.key` where `xxx` is a unique identifier for +the identity. The node will automatically pick up the new identity and manage its lifecycle after a restart. + +Set up the `post-service` [binary](https://github.com/spacemeshos/post-rs/releases) or +[docker image](https://hub.docker.com/r/spacemeshos/post-service/tags) with the data and configure it to connect to your +node. For details refer to the [post-service README](https://github.com/spacemeshos/post-rs/blob/main/service/README.md). 
+ +##### Migrating existing identities/PoST services to a node + +If you have multiple nodes running and want to migrate to use only one node for all identities: + +1. Stop all nodes. +2. Copy the `key.bin` files from the PoST data directories of all nodes to the data directory of the node you want to + use for all identities and into the folder `data/identities`. Rename the files to `xxx.key` where `xxx` is a unique + identifier for each identity. +3. Start the node managing the identities. +4. For every identity, set up a post service to use the existing PoST data for that identity and connect to the node. + For details refer to the [post-service README](https://github.com/spacemeshos/post-rs/blob/main/service/README.md). + +**WARNING:** DO NOT run multiple nodes with the same identity at the same time. This will result in an equivocation +and permanent ineligibility for rewards. + ### Highlights * [#5293](https://github.com/spacemeshos/go-spacemesh/pull/5293) change poet servers configuration The config now takes the poet server address and its public key. See the [Upgrade Information](#new-poets-configuration) for details. -* [#5219](https://github.com/spacemeshos/go-spacemesh/pull/5219) Migrate data from `nipost_builder_state.bin` to `node_state.sql`. - - The node will automatically migrate the data from disk and store it in the database. The migration will take place at the - first startup after the upgrade. - * [#5390](https://github.com/spacemeshos/go-spacemesh/pull/5390) Distributed PoST verification. @@ -111,12 +165,25 @@ configuration is as follows: If a node finds a proof invalid, it will report it to the network by creating a malfeasance proof. The malicious node will then be blacklisted by the network. +* [#5592](https://github.com/spacemeshos/go-spacemesh/pull/5592) + Extend node with option to have multiple PoST services connect. This allows users to run multiple PoST services, + without the need to run multiple nodes. 
A node can now manage multiple identities and will manage the lifecycle of + those identities. + To collect rewards for every identity, the associated PoST service must be running and connected to the node during + the cyclegap set in the node's configuration. + ### Features ### Improvements +* [#5219](https://github.com/spacemeshos/go-spacemesh/pull/5219) Migrate data from `nipost_builder_state.bin` to `node_state.sql`. + + The node will automatically migrate the data from disk and store it in the database. The migration will take place at the + first startup after the upgrade. + * [#5418](https://github.com/spacemeshos/go-spacemesh/pull/5418) Add `grpc-post-listener` to separate post service from `grpc-private-listener` and not require mTLS for the post service. + * [#5465](https://github.com/spacemeshos/go-spacemesh/pull/5465) Add an option to cache SQL query results. This is useful for nodes with high peer counts. diff --git a/activation/activation.go b/activation/activation.go index 67044a1968..8d866a1609 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -59,6 +59,7 @@ const ( // Config defines configuration for Builder. 
type Config struct { GoldenATXID types.ATXID + LabelsPerUnit uint64 RegossipInterval time.Duration } @@ -68,8 +69,7 @@ type Config struct { type Builder struct { accountLock sync.RWMutex coinbaseAccount types.Address - goldenATXID types.ATXID - regossipInterval time.Duration + conf Config cdb *datastore.CachedDB localDB *localsql.Database publisher pubsub.Publisher @@ -143,8 +143,7 @@ func NewBuilder( b := &Builder{ parentCtx: context.Background(), signers: make(map[types.NodeID]*signing.EdSigner), - goldenATXID: conf.GoldenATXID, - regossipInterval: conf.RegossipInterval, + conf: conf, cdb: cdb, localDB: localDB, publisher: publisher, @@ -165,11 +164,11 @@ func (b *Builder) Register(sig *signing.EdSigner) { b.smeshingMutex.Lock() defer b.smeshingMutex.Unlock() if _, exists := b.signers[sig.NodeID()]; exists { - b.log.Error("signing key already registered", zap.Stringer("id", sig.NodeID())) + b.log.Error("signing key already registered", log.ZShortStringer("id", sig.NodeID())) return } - b.log.Info("registered signing key", zap.Stringer("id", sig.NodeID())) + b.log.Info("registered signing key", log.ZShortStringer("id", sig.NodeID())) b.signers[sig.NodeID()] = sig if b.stop != nil { @@ -213,11 +212,11 @@ func (b *Builder) startID(ctx context.Context, sig *signing.EdSigner) { b.run(ctx, sig) return nil }) - if b.regossipInterval == 0 { + if b.conf.RegossipInterval == 0 { return } b.eg.Go(func() error { - ticker := time.NewTicker(b.regossipInterval) + ticker := time.NewTicker(b.conf.RegossipInterval) defer ticker.Stop() for { select { @@ -253,7 +252,7 @@ func (b *Builder) StopSmeshing(deleteFiles bool) error { var resetErr error for _, sig := range b.signers { if err := b.nipostBuilder.ResetState(sig.NodeID()); err != nil { - b.log.Error("failed to reset builder state", log.ZShortStringer("nodeId", sig.NodeID()), zap.Error(err)) + b.log.Error("failed to reset builder state", log.ZShortStringer("id", sig.NodeID()), zap.Error(err)) err = fmt.Errorf("reset builder state 
for id %s: %w", sig.NodeID().ShortString(), err) resetErr = errors.Join(resetErr, err) continue @@ -277,13 +276,13 @@ func (b *Builder) SmesherIDs() []types.NodeID { return maps.Keys(b.signers) } -func (b *Builder) buildInitialPost(ctx context.Context, nodeId types.NodeID) error { +func (b *Builder) buildInitialPost(ctx context.Context, nodeID types.NodeID) error { // Generate the initial POST if we don't have an ATX... - if _, err := b.cdb.GetLastAtx(nodeId); err == nil { + if _, err := b.cdb.GetLastAtx(nodeID); err == nil { return nil } // ...and if we haven't stored an initial post yet. - _, err := nipost.InitialPost(b.localDB, nodeId) + _, err := nipost.InitialPost(b.localDB, nodeID) switch { case err == nil: b.log.Info("load initial post from db") @@ -296,14 +295,10 @@ func (b *Builder) buildInitialPost(ctx context.Context, nodeId types.NodeID) err // Create the initial post and save it. startTime := time.Now() - post, postInfo, err := b.nipostBuilder.Proof(ctx, nodeId, shared.ZeroChallenge) + post, postInfo, err := b.nipostBuilder.Proof(ctx, nodeID, shared.ZeroChallenge) if err != nil { return fmt.Errorf("post execution: %w", err) } - metrics.PostDuration.Set(float64(time.Since(startTime).Nanoseconds())) - public.PostSeconds.Set(float64(time.Since(startTime))) - b.log.Info("created the initial post") - initialPost := nipost.Post{ Nonce: post.Nonce, Indices: post.Indices, @@ -313,7 +308,23 @@ func (b *Builder) buildInitialPost(ctx context.Context, nodeId types.NodeID) err CommitmentATX: postInfo.CommitmentATX, VRFNonce: *postInfo.Nonce, } - return nipost.AddInitialPost(b.localDB, nodeId, initialPost) + err = b.validator.Post(ctx, nodeID, postInfo.CommitmentATX, post, &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: postInfo.LabelsPerUnit, + }, postInfo.NumUnits) + if err != nil { + b.log.Error("initial POST is invalid", log.ZShortStringer("smesherID", nodeID), zap.Error(err)) + if err := nipost.RemoveInitialPost(b.localDB, nodeID); err 
!= nil { + b.log.Fatal("failed to remove initial post", log.ZShortStringer("smesherID", nodeID), zap.Error(err)) + } + return fmt.Errorf("initial POST is invalid: %w", err) + } + + metrics.PostDuration.Set(float64(time.Since(startTime).Nanoseconds())) + public.PostSeconds.Set(float64(time.Since(startTime))) + b.log.Info("created the initial post") + + return nipost.AddInitialPost(b.localDB, nodeID, initialPost) } func (b *Builder) run(ctx context.Context, sig *signing.EdSigner) { @@ -379,7 +390,7 @@ func (b *Builder) run(ctx context.Context, sig *signing.EdSigner) { } } -func (b *Builder) buildNIPostChallenge(ctx context.Context, nodeID types.NodeID) (*types.NIPostChallenge, error) { +func (b *Builder) BuildNIPostChallenge(ctx context.Context, nodeID types.NodeID) (*types.NIPostChallenge, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -451,6 +462,23 @@ func (b *Builder) buildNIPostChallenge(ctx context.Context, nodeID types.NodeID) if err != nil { return nil, fmt.Errorf("get initial post: %w", err) } + b.log.Info("verifying the initial post") + initialPost := &types.Post{ + Nonce: post.Nonce, + Indices: post.Indices, + Pow: post.Pow, + } + err = b.validator.Post(ctx, nodeID, post.CommitmentATX, initialPost, &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: b.conf.LabelsPerUnit, + }, post.NumUnits) + if err != nil { + b.log.Error("initial POST is invalid", log.ZShortStringer("smesherID", nodeID), zap.Error(err)) + if err := nipost.RemoveInitialPost(b.localDB, nodeID); err != nil { + b.log.Fatal("failed to remove initial post", log.ZShortStringer("smesherID", nodeID), zap.Error(err)) + } + return nil, fmt.Errorf("initial POST is invalid: %w", err) + } challenge = &types.NIPostChallenge{ PublishEpoch: current + 1, Sequence: 0, @@ -498,7 +526,7 @@ func (b *Builder) Coinbase() types.Address { // PublishActivationTx attempts to publish an atx, it returns an error if an atx cannot be created. 
func (b *Builder) PublishActivationTx(ctx context.Context, sig *signing.EdSigner) error { - challenge, err := b.buildNIPostChallenge(ctx, sig.NodeID()) + challenge, err := b.BuildNIPostChallenge(ctx, sig.NodeID()) if err != nil { return err } @@ -630,7 +658,7 @@ func (b *Builder) getPositioningAtx(ctx context.Context, nodeID types.NodeID) (t ctx, b.cdb, nodeID, - b.goldenATXID, + b.conf.GoldenATXID, b.validator, b.log, VerifyChainOpts.AssumeValidBefore(time.Now().Add(-b.postValidityDelay)), @@ -639,7 +667,7 @@ func (b *Builder) getPositioningAtx(ctx context.Context, nodeID types.NodeID) (t ) if errors.Is(err, sql.ErrNotFound) { b.log.Info("using golden atx as positioning atx") - return b.goldenATXID, nil + return b.conf.GoldenATXID, nil } return id, err } diff --git a/activation/activation_multi_test.go b/activation/activation_multi_test.go index 38d7fe265d..e464442411 100644 --- a/activation/activation_multi_test.go +++ b/activation/activation_multi_test.go @@ -222,15 +222,33 @@ func TestRegossip(t *testing.T) { func Test_Builder_Multi_InitialPost(t *testing.T) { tab := newTestBuilder(t, 5, WithPoetConfig(PoetConfig{PhaseShift: layerDuration * 4})) + var eg errgroup.Group for _, sig := range tab.signers { sig := sig eg.Go(func() error { + numUnits := uint32(12) + + post := &types.Post{ + Indices: types.RandomBytes(10), + Nonce: rand.Uint32(), + Pow: rand.Uint64(), + } + meta := &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: tab.conf.LabelsPerUnit, + } + + commitmentATX := types.RandomATXID() + tab.mValidator.EXPECT().Post(gomock.Any(), sig.NodeID(), commitmentATX, post, meta, numUnits).Return(nil) tab.mnipost.EXPECT().Proof(gomock.Any(), sig.NodeID(), shared.ZeroChallenge).Return( - &types.Post{Indices: make([]byte, 10)}, + post, &types.PostInfo{ - CommitmentATX: types.RandomATXID(), + CommitmentATX: commitmentATX, Nonce: new(types.VRFPostIndex), + NumUnits: numUnits, + NodeID: sig.NodeID(), + LabelsPerUnit: tab.conf.LabelsPerUnit, }, 
nil, ) @@ -249,7 +267,6 @@ func Test_Builder_Multi_InitialPost(t *testing.T) { func Test_Builder_Multi_HappyPath(t *testing.T) { layerDuration := 2 * time.Second tab := newTestBuilder(t, 3, WithPoetConfig(PoetConfig{PhaseShift: layerDuration * 4, CycleGap: layerDuration})) - tab.regossipInterval = 0 // disable regossip for testing // step 1: build initial posts initialPostChan := make(chan struct{}) @@ -264,12 +281,23 @@ func Test_Builder_Multi_HappyPath(t *testing.T) { Nonce: rand.Uint32(), Pow: rand.Uint64(), - NumUnits: 4, + NumUnits: uint32(12), CommitmentATX: types.RandomATXID(), VRFNonce: types.VRFPostIndex(rand.Uint64()), } initialPost[sig.NodeID()] = &nipost + post := &types.Post{ + Indices: nipost.Indices, + Nonce: nipost.Nonce, + Pow: nipost.Pow, + } + meta := &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: tab.conf.LabelsPerUnit, + } + tab.mValidator.EXPECT().Post(gomock.Any(), sig.NodeID(), nipost.CommitmentATX, post, meta, nipost.NumUnits). + Return(nil) tab.mnipost.EXPECT().Proof(gomock.Any(), sig.NodeID(), shared.ZeroChallenge).DoAndReturn( func(ctx context.Context, _ types.NodeID, _ []byte) (*types.Post, *types.PostInfo, error) { <-initialPostChan @@ -283,6 +311,7 @@ func Test_Builder_Multi_HappyPath(t *testing.T) { NumUnits: nipost.NumUnits, CommitmentATX: nipost.CommitmentATX, Nonce: &nipost.VRFNonce, + LabelsPerUnit: tab.conf.LabelsPerUnit, } return post, postInfo, nil @@ -315,6 +344,19 @@ func Test_Builder_Multi_HappyPath(t *testing.T) { return postGenesisEpoch.FirstLayer() + 1 }, ) + + nipost := initialPost[sig.NodeID()] + post := &types.Post{ + Indices: nipost.Indices, + Nonce: nipost.Nonce, + Pow: nipost.Pow, + } + meta := &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: tab.conf.LabelsPerUnit, + } + tab.mValidator.EXPECT().Post(gomock.Any(), sig.NodeID(), nipost.CommitmentATX, post, meta, nipost.NumUnits). 
+ Return(nil) } // step 3: create ATX diff --git a/activation/activation_test.go b/activation/activation_test.go index 3b6fa1e6e6..1e9704f710 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/rand" "os" "testing" "time" @@ -143,7 +144,8 @@ func newTestBuilder(tb testing.TB, numSigners int, opts ...BuilderOption) *testA opts = append(opts, WithValidator(tab.mValidator)) cfg := Config{ - GoldenATXID: tab.goldenATXID, + GoldenATXID: tab.goldenATXID, + LabelsPerUnit: DefaultPostConfig().LabelsPerUnit, } tab.msync.EXPECT().RegisterForATXSynced().DoAndReturn(closedChan).AnyTimes() @@ -785,11 +787,27 @@ func TestBuilder_PublishActivationTx_NoPrevATX(t *testing.T) { require.NoError(t, atxs.Add(tab.cdb, vPosAtx)) // generate and store initial post in state - require.NoError(t, nipost.AddInitialPost( - tab.localDb, - sig.NodeID(), - nipost.Post{Indices: make([]byte, 10)}, - )) + post := nipost.Post{ + Indices: types.RandomBytes(10), + Nonce: rand.Uint32(), + Pow: rand.Uint64(), + + NumUnits: uint32(12), + CommitmentATX: types.RandomATXID(), + VRFNonce: types.VRFPostIndex(rand.Uint64()), + } + require.NoError(t, nipost.AddInitialPost(tab.localDb, sig.NodeID(), post)) + initialPost := &types.Post{ + Nonce: post.Nonce, + Indices: post.Indices, + Pow: post.Pow, + } + meta := &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: tab.conf.LabelsPerUnit, + } + tab.mValidator.EXPECT().Post(gomock.Any(), sig.NodeID(), post.CommitmentATX, initialPost, meta, post.NumUnits). 
+ Return(nil) // create and publish ATX tab.mclock.EXPECT().CurrentLayer().Return(currLayer).AnyTimes() @@ -824,12 +842,28 @@ func TestBuilder_PublishActivationTx_NoPrevATX_PublishFails_InitialPost_preserve require.NoError(t, err) require.NoError(t, atxs.Add(tab.cdb, vPosAtx)) - // generate and store initial post in state + // generate and store initial refPost in state refPost := nipost.Post{ - Indices: make([]byte, 10), + Indices: types.RandomBytes(10), + Nonce: rand.Uint32(), + Pow: rand.Uint64(), + + NumUnits: uint32(12), CommitmentATX: types.RandomATXID(), + VRFNonce: types.VRFPostIndex(rand.Uint64()), } require.NoError(t, nipost.AddInitialPost(tab.localDb, sig.NodeID(), refPost)) + initialPost := &types.Post{ + Nonce: refPost.Nonce, + Indices: refPost.Indices, + Pow: refPost.Pow, + } + meta := &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: tab.conf.LabelsPerUnit, + } + tab.mValidator.EXPECT().Post(gomock.Any(), sig.NodeID(), refPost.CommitmentATX, initialPost, meta, refPost.NumUnits). + Return(nil) // create and publish ATX tab.mclock.EXPECT().CurrentLayer().Return(currLayer).AnyTimes() @@ -1088,7 +1122,27 @@ func TestBuilder_PublishActivationTx_TargetsEpochBasedOnPosAtx(t *testing.T) { return nil }) - require.NoError(t, nipost.AddInitialPost(tab.localDb, sig.NodeID(), nipost.Post{Indices: make([]byte, 10)})) + post := nipost.Post{ + Indices: types.RandomBytes(10), + Nonce: rand.Uint32(), + Pow: rand.Uint64(), + + NumUnits: uint32(12), + CommitmentATX: types.RandomATXID(), + VRFNonce: types.VRFPostIndex(rand.Uint64()), + } + require.NoError(t, nipost.AddInitialPost(tab.localDb, sig.NodeID(), post)) + initialPost := &types.Post{ + Nonce: post.Nonce, + Indices: post.Indices, + Pow: post.Pow, + } + meta := &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: tab.conf.LabelsPerUnit, + } + tab.mValidator.EXPECT().Post(gomock.Any(), sig.NodeID(), post.CommitmentATX, initialPost, meta, post.NumUnits). 
+ Return(nil) tab.mnipost.EXPECT().ResetState(sig.NodeID()).Return(nil) @@ -1349,14 +1403,38 @@ func TestBuilder_InitialProofGeneratedOnce(t *testing.T) { tab := newTestBuilder(t, 1, WithPoetConfig(PoetConfig{PhaseShift: layerDuration * 4})) sig := maps.Values(tab.signers)[0] + post := nipost.Post{ + Indices: types.RandomBytes(10), + Nonce: rand.Uint32(), + Pow: rand.Uint64(), + + NumUnits: uint32(12), + CommitmentATX: types.RandomATXID(), + VRFNonce: types.VRFPostIndex(rand.Uint64()), + } + initialPost := &types.Post{ + Nonce: post.Nonce, + Indices: post.Indices, + Pow: post.Pow, + } tab.mnipost.EXPECT().Proof(gomock.Any(), sig.NodeID(), shared.ZeroChallenge).Return( - &types.Post{Indices: make([]byte, 10)}, + initialPost, &types.PostInfo{ - CommitmentATX: types.RandomATXID(), - Nonce: new(types.VRFPostIndex), + NodeID: sig.NodeID(), + CommitmentATX: post.CommitmentATX, + Nonce: &post.VRFNonce, + + NumUnits: post.NumUnits, + LabelsPerUnit: tab.conf.LabelsPerUnit, }, nil, ) + meta := &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: tab.conf.LabelsPerUnit, + } + tab.mValidator.EXPECT().Post(gomock.Any(), sig.NodeID(), post.CommitmentATX, initialPost, meta, post.NumUnits). 
+ Return(nil) require.NoError(t, tab.buildInitialPost(context.Background(), sig.NodeID())) posEpoch := postGenesisEpoch + 1 @@ -1398,14 +1476,32 @@ func TestBuilder_InitialPostIsPersisted(t *testing.T) { tab := newTestBuilder(t, 1, WithPoetConfig(PoetConfig{PhaseShift: layerDuration * 4})) sig := maps.Values(tab.signers)[0] + commitmentATX := types.RandomATXID() + nonce := types.VRFPostIndex(rand.Uint64()) + numUnits := uint32(12) + initialPost := &types.Post{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(10), + Pow: rand.Uint64(), + } + meta := &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: tab.conf.LabelsPerUnit, + } tab.mnipost.EXPECT().Proof(gomock.Any(), sig.NodeID(), shared.ZeroChallenge).Return( - &types.Post{Indices: make([]byte, 10)}, + initialPost, &types.PostInfo{ - CommitmentATX: types.RandomATXID(), - Nonce: new(types.VRFPostIndex), + NodeID: sig.NodeID(), + CommitmentATX: commitmentATX, + Nonce: &nonce, + + NumUnits: numUnits, + LabelsPerUnit: tab.conf.LabelsPerUnit, }, nil, ) + tab.mValidator.EXPECT().Post(gomock.Any(), sig.NodeID(), commitmentATX, initialPost, meta, numUnits). 
+ Return(nil) require.NoError(t, tab.buildInitialPost(context.Background(), sig.NodeID())) // postClient.Proof() should not be called again @@ -1454,11 +1550,27 @@ func TestWaitPositioningAtx(t *testing.T) { return nil }) - require.NoError(t, nipost.AddInitialPost( - tab.localDb, - sig.NodeID(), - nipost.Post{Indices: make([]byte, 10)}, - )) + post := nipost.Post{ + Indices: types.RandomBytes(10), + Nonce: rand.Uint32(), + Pow: rand.Uint64(), + + NumUnits: uint32(12), + CommitmentATX: types.RandomATXID(), + VRFNonce: types.VRFPostIndex(rand.Uint64()), + } + require.NoError(t, nipost.AddInitialPost(tab.localDb, sig.NodeID(), post)) + initialPost := &types.Post{ + Nonce: post.Nonce, + Indices: post.Indices, + Pow: post.Pow, + } + meta := &types.PostMetadata{ + Challenge: shared.ZeroChallenge, + LabelsPerUnit: tab.conf.LabelsPerUnit, + } + tab.mValidator.EXPECT().Post(gomock.Any(), sig.NodeID(), post.CommitmentATX, initialPost, meta, post.NumUnits). + Return(nil) require.NoError(t, tab.PublishActivationTx(context.Background(), sig)) }) diff --git a/activation/e2e/activation_test.go b/activation/e2e/activation_test.go index d27b25ca57..b69124883a 100644 --- a/activation/e2e/activation_test.go +++ b/activation/e2e/activation_test.go @@ -2,6 +2,7 @@ package activation_test import ( "context" + "math/rand" "sync" "testing" "time" @@ -66,12 +67,13 @@ func Test_BuilderWithMultipleClients(t *testing.T) { opts := opts eg.Go(func() error { validator := activation.NewMocknipostValidator(ctrl) - mgr, err := activation.NewPostSetupManager(sig.NodeID(), cfg, logger, cdb, goldenATX, syncer, validator) + mgr, err := activation.NewPostSetupManager(cfg, logger, cdb, goldenATX, syncer, validator) require.NoError(t, err) opts.DataDir = t.TempDir() - initPost(t, mgr, opts) - t.Cleanup(launchPostSupervisor(t, logger, mgr, grpcCfg, opts)) + opts.NumUnits = uint32(rand.Int31n(int32(cfg.MaxNumUnits/2-cfg.MinNumUnits))) + cfg.MinNumUnits + initPost(t, mgr, opts, sig.NodeID()) + 
t.Cleanup(launchPostSupervisor(t, logger, mgr, sig.NodeID(), grpcCfg, opts)) require.Eventually(t, func() bool { _, err := svc.Client(sig.NodeID()) @@ -148,6 +150,10 @@ func Test_BuilderWithMultipleClients(t *testing.T) { }, ).Times(numSigners) + verifier, err := activation.NewPostVerifier(cfg, logger.Named("verifier")) + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, verifier.Close()) }) + v := activation.NewValidator(nil, poetDb, cfg, opts.Scrypt, verifier) tab := activation.NewBuilder( conf, cdb, @@ -158,22 +164,16 @@ func Test_BuilderWithMultipleClients(t *testing.T) { syncer, logger, activation.WithPoetConfig(poetCfg), + activation.WithValidator(v), ) for _, sig := range signers { tab.Register(sig) } require.NoError(t, tab.StartSmeshing(types.Address{})) - <-endChan - require.NoError(t, tab.StopSmeshing(false)) - verifier, err := activation.NewPostVerifier(cfg, logger.Named("verifier")) - require.NoError(t, err) - t.Cleanup(func() { assert.NoError(t, verifier.Close()) }) - - v := activation.NewValidator(nil, poetDb, cfg, opts.Scrypt, verifier) for _, sig := range signers { atx := atxs[sig.NodeID()] diff --git a/activation/e2e/nipost_test.go b/activation/e2e/nipost_test.go index 1ba62cef65..923c438ff6 100644 --- a/activation/e2e/nipost_test.go +++ b/activation/e2e/nipost_test.go @@ -64,6 +64,7 @@ func launchPostSupervisor( tb testing.TB, log *zap.Logger, mgr *activation.PostSetupManager, + id types.NodeID, cfg grpcserver.Config, postOpts activation.PostSetupOpts, ) func() { @@ -76,7 +77,7 @@ func launchPostSupervisor( ps, err := activation.NewPostSupervisor(log, cmdCfg, postCfg, provingOpts, mgr) require.NoError(tb, err) require.NotNil(tb, ps) - require.NoError(tb, ps.Start(postOpts)) + require.NoError(tb, ps.Start(postOpts, id)) return func() { assert.NoError(tb, ps.Stop(false)) } } @@ -99,12 +100,12 @@ func launchServer(tb testing.TB, services ...grpcserver.ServiceAPI) (grpcserver. 
return cfg, func() { assert.NoError(tb, server.Close()) } } -func initPost(tb testing.TB, mgr *activation.PostSetupManager, opts activation.PostSetupOpts) { +func initPost(tb testing.TB, mgr *activation.PostSetupManager, opts activation.PostSetupOpts, id types.NodeID) { tb.Helper() // Create data. - require.NoError(tb, mgr.PrepareInitializer(context.Background(), opts)) - require.NoError(tb, mgr.StartSession(context.Background())) + require.NoError(tb, mgr.PrepareInitializer(context.Background(), opts, id)) + require.NoError(tb, mgr.StartSession(context.Background(), id)) require.Equal(tb, activation.PostSetupStateComplete, mgr.Status().State) } @@ -128,14 +129,14 @@ func TestNIPostBuilderWithClients(t *testing.T) { }) validator := activation.NewMocknipostValidator(ctrl) - mgr, err := activation.NewPostSetupManager(sig.NodeID(), cfg, logger, cdb, goldenATX, syncer, validator) + mgr, err := activation.NewPostSetupManager(cfg, logger, cdb, goldenATX, syncer, validator) require.NoError(t, err) opts := activation.DefaultPostSetupOpts() opts.DataDir = t.TempDir() opts.ProviderID.SetUint32(initialization.CPUProviderID()) opts.Scrypt.N = 2 // Speedup initialization in tests. 
- initPost(t, mgr, opts) + initPost(t, mgr, opts, sig.NodeID()) // ensure that genesis aligns with layer timings genesis := time.Now().Add(layerDuration).Round(layerDuration) @@ -173,7 +174,7 @@ func TestNIPostBuilderWithClients(t *testing.T) { grpcCfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) - t.Cleanup(launchPostSupervisor(t, logger, mgr, grpcCfg, opts)) + t.Cleanup(launchPostSupervisor(t, logger, mgr, sig.NodeID(), grpcCfg, opts)) require.Eventually(t, func() bool { _, err := svc.Client(sig.NodeID()) @@ -274,7 +275,7 @@ func TestNewNIPostBuilderNotInitialized(t *testing.T) { }) validator := activation.NewMocknipostValidator(ctrl) - mgr, err := activation.NewPostSetupManager(sig.NodeID(), cfg, logger, cdb, goldenATX, syncer, validator) + mgr, err := activation.NewPostSetupManager(cfg, logger, cdb, goldenATX, syncer, validator) require.NoError(t, err) // ensure that genesis aligns with layer timings @@ -325,7 +326,7 @@ func TestNewNIPostBuilderNotInitialized(t *testing.T) { opts.DataDir = t.TempDir() opts.ProviderID.SetUint32(initialization.CPUProviderID()) opts.Scrypt.N = 2 // Speedup initialization in tests. 
- t.Cleanup(launchPostSupervisor(t, logger, mgr, grpcCfg, opts)) + t.Cleanup(launchPostSupervisor(t, logger, mgr, sig.NodeID(), grpcCfg, opts)) require.Eventually(t, func() bool { _, err := svc.Client(sig.NodeID()) @@ -393,12 +394,12 @@ func Test_NIPostBuilderWithMultipleClients(t *testing.T) { sig := sig opts := opts eg.Go(func() error { - mgr, err := activation.NewPostSetupManager(sig.NodeID(), cfg, logger, cdb, goldenATX, syncer, validator) + mgr, err := activation.NewPostSetupManager(cfg, logger, cdb, goldenATX, syncer, validator) require.NoError(t, err) opts.DataDir = t.TempDir() - initPost(t, mgr, opts) - t.Cleanup(launchPostSupervisor(t, logger, mgr, grpcCfg, opts)) + initPost(t, mgr, opts, sig.NodeID()) + t.Cleanup(launchPostSupervisor(t, logger, mgr, sig.NodeID(), grpcCfg, opts)) require.Eventually(t, func() bool { _, err := svc.Client(sig.NodeID()) diff --git a/activation/e2e/validation_test.go b/activation/e2e/validation_test.go index 376a059f91..045e0fc407 100644 --- a/activation/e2e/validation_test.go +++ b/activation/e2e/validation_test.go @@ -42,14 +42,14 @@ func TestValidator_Validate(t *testing.T) { return synced }) - mgr, err := activation.NewPostSetupManager(sig.NodeID(), cfg, logger, cdb, goldenATX, syncer, validator) + mgr, err := activation.NewPostSetupManager(cfg, logger, cdb, goldenATX, syncer, validator) require.NoError(t, err) opts := activation.DefaultPostSetupOpts() opts.DataDir = t.TempDir() opts.ProviderID.SetUint32(initialization.CPUProviderID()) opts.Scrypt.N = 2 // Speedup initialization in tests. 
- initPost(t, mgr, opts) + initPost(t, mgr, opts, sig.NodeID()) // ensure that genesis aligns with layer timings genesis := time.Now().Add(layerDuration).Round(layerDuration) @@ -87,7 +87,7 @@ func TestValidator_Validate(t *testing.T) { grpcCfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) - t.Cleanup(launchPostSupervisor(t, logger, mgr, grpcCfg, opts)) + t.Cleanup(launchPostSupervisor(t, logger, mgr, sig.NodeID(), grpcCfg, opts)) require.Eventually(t, func() bool { _, err := svc.Client(sig.NodeID()) diff --git a/activation/handler.go b/activation/handler.go index 0a93dd3a2a..7856ecc6f7 100644 --- a/activation/handler.go +++ b/activation/handler.go @@ -9,7 +9,6 @@ import ( "github.com/spacemeshos/post/shared" "github.com/spacemeshos/post/verifying" - "go.uber.org/zap" "golang.org/x/exp/maps" "github.com/spacemeshos/go-spacemesh/atxsdata" @@ -100,11 +99,11 @@ func (h *Handler) Register(sig *signing.EdSigner) { h.signerMtx.Lock() defer h.signerMtx.Unlock() if _, exists := h.signers[sig.NodeID()]; exists { - h.log.Error("signing key already registered", zap.Stringer("id", sig.NodeID())) + h.log.With().Error("signing key already registered", log.ShortStringer("id", sig.NodeID())) return } - h.log.Info("registered signing key", zap.Stringer("id", sig.NodeID())) + h.log.With().Info("registered signing key", log.ShortStringer("id", sig.NodeID())) h.signers[sig.NodeID()] = sig } diff --git a/activation/interface.go b/activation/interface.go index e196f9eeeb..05ab56aa5a 100644 --- a/activation/interface.go +++ b/activation/interface.go @@ -97,8 +97,8 @@ type atxProvider interface { // This interface is used by the atx builder and currently implemented by the PostSetupManager. // Eventually most of the functionality will be moved to the PoSTClient. 
type postSetupProvider interface { - PrepareInitializer(ctx context.Context, opts PostSetupOpts) error - StartSession(context context.Context) error + PrepareInitializer(ctx context.Context, opts PostSetupOpts, id types.NodeID) error + StartSession(context context.Context, id types.NodeID) error Status() *PostSetupStatus Reset() error } diff --git a/activation/mocks.go b/activation/mocks.go index ab719f0caa..3456b0bda7 100644 --- a/activation/mocks.go +++ b/activation/mocks.go @@ -1087,17 +1087,17 @@ func (m *MockpostSetupProvider) EXPECT() *MockpostSetupProviderMockRecorder { } // PrepareInitializer mocks base method. -func (m *MockpostSetupProvider) PrepareInitializer(ctx context.Context, opts PostSetupOpts) error { +func (m *MockpostSetupProvider) PrepareInitializer(ctx context.Context, opts PostSetupOpts, id types.NodeID) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PrepareInitializer", ctx, opts) + ret := m.ctrl.Call(m, "PrepareInitializer", ctx, opts, id) ret0, _ := ret[0].(error) return ret0 } // PrepareInitializer indicates an expected call of PrepareInitializer. 
-func (mr *MockpostSetupProviderMockRecorder) PrepareInitializer(ctx, opts any) *MockpostSetupProviderPrepareInitializerCall { +func (mr *MockpostSetupProviderMockRecorder) PrepareInitializer(ctx, opts, id any) *MockpostSetupProviderPrepareInitializerCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrepareInitializer", reflect.TypeOf((*MockpostSetupProvider)(nil).PrepareInitializer), ctx, opts) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrepareInitializer", reflect.TypeOf((*MockpostSetupProvider)(nil).PrepareInitializer), ctx, opts, id) return &MockpostSetupProviderPrepareInitializerCall{Call: call} } @@ -1113,13 +1113,13 @@ func (c *MockpostSetupProviderPrepareInitializerCall) Return(arg0 error) *Mockpo } // Do rewrite *gomock.Call.Do -func (c *MockpostSetupProviderPrepareInitializerCall) Do(f func(context.Context, PostSetupOpts) error) *MockpostSetupProviderPrepareInitializerCall { +func (c *MockpostSetupProviderPrepareInitializerCall) Do(f func(context.Context, PostSetupOpts, types.NodeID) error) *MockpostSetupProviderPrepareInitializerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockpostSetupProviderPrepareInitializerCall) DoAndReturn(f func(context.Context, PostSetupOpts) error) *MockpostSetupProviderPrepareInitializerCall { +func (c *MockpostSetupProviderPrepareInitializerCall) DoAndReturn(f func(context.Context, PostSetupOpts, types.NodeID) error) *MockpostSetupProviderPrepareInitializerCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -1163,17 +1163,17 @@ func (c *MockpostSetupProviderResetCall) DoAndReturn(f func() error) *MockpostSe } // StartSession mocks base method. 
-func (m *MockpostSetupProvider) StartSession(context context.Context) error { +func (m *MockpostSetupProvider) StartSession(context context.Context, id types.NodeID) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StartSession", context) + ret := m.ctrl.Call(m, "StartSession", context, id) ret0, _ := ret[0].(error) return ret0 } // StartSession indicates an expected call of StartSession. -func (mr *MockpostSetupProviderMockRecorder) StartSession(context any) *MockpostSetupProviderStartSessionCall { +func (mr *MockpostSetupProviderMockRecorder) StartSession(context, id any) *MockpostSetupProviderStartSessionCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSession", reflect.TypeOf((*MockpostSetupProvider)(nil).StartSession), context) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSession", reflect.TypeOf((*MockpostSetupProvider)(nil).StartSession), context, id) return &MockpostSetupProviderStartSessionCall{Call: call} } @@ -1189,13 +1189,13 @@ func (c *MockpostSetupProviderStartSessionCall) Return(arg0 error) *MockpostSetu } // Do rewrite *gomock.Call.Do -func (c *MockpostSetupProviderStartSessionCall) Do(f func(context.Context) error) *MockpostSetupProviderStartSessionCall { +func (c *MockpostSetupProviderStartSessionCall) Do(f func(context.Context, types.NodeID) error) *MockpostSetupProviderStartSessionCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockpostSetupProviderStartSessionCall) DoAndReturn(f func(context.Context) error) *MockpostSetupProviderStartSessionCall { +func (c *MockpostSetupProviderStartSessionCall) DoAndReturn(f func(context.Context, types.NodeID) error) *MockpostSetupProviderStartSessionCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/activation/post.go b/activation/post.go index 38ee4e0fde..431e3cb52f 100644 --- a/activation/post.go +++ b/activation/post.go @@ -175,7 +175,6 @@ func (o PostSetupOpts) ToInitOpts() 
config.InitOpts { // PostSetupManager implements the PostProvider interface. type PostSetupManager struct { - id types.NodeID commitmentAtxId types.ATXID syncer syncer @@ -206,7 +205,6 @@ func PostValidityDelay(delay time.Duration) PostSetupManagerOpt { // NewPostSetupManager creates a new instance of PostSetupManager. func NewPostSetupManager( - id types.NodeID, cfg PostConfig, logger *zap.Logger, db *datastore.CachedDB, @@ -216,7 +214,6 @@ func NewPostSetupManager( opts ...PostSetupManagerOpt, ) (*PostSetupManager, error) { mgr := &PostSetupManager{ - id: id, cfg: cfg, logger: logger, db: db, @@ -260,7 +257,7 @@ func (mgr *PostSetupManager) Status() *PostSetupStatus { // previously started session, and will return an error if a session is already // in progress. It must be ensured that PrepareInitializer is called once // before each call to StartSession and that the node is ATX synced. -func (mgr *PostSetupManager) StartSession(ctx context.Context) error { +func (mgr *PostSetupManager) StartSession(ctx context.Context, id types.NodeID) error { // Ensure only one goroutine can execute initialization at a time. 
err := func() error { mgr.mu.Lock() @@ -275,7 +272,7 @@ func (mgr *PostSetupManager) StartSession(ctx context.Context) error { return err } mgr.logger.Info("post setup session starting", - zap.Stringer("node_id", mgr.id), + zap.Stringer("node_id", id), zap.Stringer("commitment_atx", mgr.commitmentAtxId), zap.String("data_dir", mgr.lastOpts.DataDir), zap.Uint32("num_units", mgr.lastOpts.NumUnits), @@ -283,7 +280,7 @@ func (mgr *PostSetupManager) StartSession(ctx context.Context) error { zap.Stringer("provider", mgr.lastOpts.ProviderID), ) public.InitStart.Set(float64(mgr.lastOpts.NumUnits)) - events.EmitInitStart(mgr.id, mgr.commitmentAtxId) + events.EmitInitStart(id, mgr.commitmentAtxId) err = mgr.init.Initialize(ctx) mgr.mu.Lock() @@ -300,19 +297,19 @@ func (mgr *PostSetupManager) StartSession(ctx context.Context) error { zap.Error(errLabelMismatch), ) mgr.state = PostSetupStateError - events.EmitInitFailure(mgr.id, mgr.commitmentAtxId, errLabelMismatch) + events.EmitInitFailure(id, mgr.commitmentAtxId, errLabelMismatch) return nil case err != nil: mgr.logger.Error("post setup session failed", zap.Error(err)) mgr.state = PostSetupStateError - events.EmitInitFailure(mgr.id, mgr.commitmentAtxId, err) + events.EmitInitFailure(id, mgr.commitmentAtxId, err) return err } public.InitEnd.Set(float64(mgr.lastOpts.NumUnits)) events.EmitInitComplete() mgr.logger.Info("post setup completed", - zap.Stringer("node_id", mgr.id), + zap.Stringer("node_id", id), zap.Stringer("commitment_atx", mgr.commitmentAtxId), zap.String("data_dir", mgr.lastOpts.DataDir), zap.Uint32("num_units", mgr.lastOpts.NumUnits), @@ -330,7 +327,7 @@ func (mgr *PostSetupManager) StartSession(ctx context.Context) error { // (StartSession can take days to complete). After the first call to this // method subsequent calls to this method will return an error until // StartSession has completed execution. 
-func (mgr *PostSetupManager) PrepareInitializer(ctx context.Context, opts PostSetupOpts) error { +func (mgr *PostSetupManager) PrepareInitializer(ctx context.Context, opts PostSetupOpts, id types.NodeID) error { mgr.logger.Info("preparing post initializer", zap.Any("opts", opts)) mgr.mu.Lock() defer mgr.mu.Unlock() @@ -339,13 +336,13 @@ func (mgr *PostSetupManager) PrepareInitializer(ctx context.Context, opts PostSe } var err error - mgr.commitmentAtxId, err = mgr.commitmentAtx(ctx, opts.DataDir) + mgr.commitmentAtxId, err = mgr.commitmentAtx(ctx, opts.DataDir, id) if err != nil { return err } newInit, err := initialization.NewInitializer( - initialization.WithNodeId(mgr.id.Bytes()), + initialization.WithNodeId(id.Bytes()), initialization.WithCommitmentAtxId(mgr.commitmentAtxId.Bytes()), initialization.WithConfig(mgr.cfg.ToConfig()), initialization.WithInitOpts(opts.ToInitOpts()), @@ -362,14 +359,14 @@ func (mgr *PostSetupManager) PrepareInitializer(ctx context.Context, opts PostSe return nil } -func (mgr *PostSetupManager) commitmentAtx(ctx context.Context, dataDir string) (types.ATXID, error) { +func (mgr *PostSetupManager) commitmentAtx(ctx context.Context, dataDir string, id types.NodeID) (types.ATXID, error) { m, err := initialization.LoadMetadata(dataDir) switch { case err == nil: return types.ATXID(types.BytesToHash(m.CommitmentAtxId)), nil case errors.Is(err, initialization.ErrStateMetadataFileMissing): // if this node has already published an ATX, get its initial ATX and from it the commitment ATX - atxId, err := atxs.GetFirstIDByNodeID(mgr.db, mgr.id) + atxId, err := atxs.GetFirstIDByNodeID(mgr.db, id) if err == nil { atx, err := atxs.Get(mgr.db, atxId) if err != nil { diff --git a/activation/post_supervisor.go b/activation/post_supervisor.go index c59f04ec81..f33928ffea 100644 --- a/activation/post_supervisor.go +++ b/activation/post_supervisor.go @@ -18,6 +18,7 @@ import ( "go.uber.org/zap" "golang.org/x/sync/errgroup" + 
"github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" ) @@ -131,7 +132,7 @@ func (ps *PostSupervisor) Status() *PostSetupStatus { return ps.postSetupProvider.Status() } -func (ps *PostSupervisor) Start(opts PostSetupOpts) error { +func (ps *PostSupervisor) Start(opts PostSetupOpts, id types.NodeID) error { ps.mtx.Lock() defer ps.mtx.Unlock() if ps.stop != nil { @@ -146,7 +147,7 @@ func (ps *PostSupervisor) Start(opts PostSetupOpts) error { ps.eg.Go(func() error { // If it returns any error other than context.Canceled // (which is how we signal it to stop) then we shutdown. - err := ps.postSetupProvider.PrepareInitializer(ctx, opts) + err := ps.postSetupProvider.PrepareInitializer(ctx, opts, id) switch { case errors.Is(err, context.Canceled): return nil @@ -155,7 +156,7 @@ func (ps *PostSupervisor) Start(opts PostSetupOpts) error { return err } - err = ps.postSetupProvider.StartSession(ctx) + err = ps.postSetupProvider.StartSession(ctx, id) switch { case errors.Is(err, context.Canceled): return nil diff --git a/activation/post_supervisor_test.go b/activation/post_supervisor_test.go index cee1657e1d..f0a2900c08 100644 --- a/activation/post_supervisor_test.go +++ b/activation/post_supervisor_test.go @@ -15,6 +15,8 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest" + + "github.com/spacemeshos/go-spacemesh/common/types" ) func closedChan() <-chan struct{} { @@ -57,16 +59,17 @@ func Test_PostSupervisor_Start_FailPrepare(t *testing.T) { postCfg := DefaultPostConfig() postOpts := DefaultPostSetupOpts() provingOpts := DefaultPostProvingOpts() + nodeID := types.RandomNodeID() mgr := NewMockpostSetupProvider(gomock.NewController(t)) testErr := errors.New("test error") - mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts).Return(testErr) + mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts, nodeID).Return(testErr) ps, err := NewPostSupervisor(log.Named("supervisor"), cmdCfg, postCfg, provingOpts, 
mgr) require.NoError(t, err) require.NotNil(t, ps) - require.NoError(t, ps.Start(postOpts)) + require.NoError(t, ps.Start(postOpts, nodeID)) require.ErrorIs(t, ps.Stop(false), testErr) } @@ -91,17 +94,18 @@ func Test_PostSupervisor_Start_FailStartSession(t *testing.T) { postCfg := DefaultPostConfig() postOpts := DefaultPostSetupOpts() provingOpts := DefaultPostProvingOpts() + nodeID := types.RandomNodeID() ctrl := gomock.NewController(t) mgr := NewMockpostSetupProvider(ctrl) - mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts).Return(nil) - mgr.EXPECT().StartSession(gomock.Any()).Return(errors.New("failed start session")) + mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts, nodeID).Return(nil) + mgr.EXPECT().StartSession(gomock.Any(), nodeID).Return(errors.New("failed start session")) ps, err := NewPostSupervisor(log.Named("supervisor"), cmdCfg, postCfg, provingOpts, mgr) require.NoError(t, err) require.NotNil(t, ps) - require.NoError(t, ps.Start(postOpts)) + require.NoError(t, ps.Start(postOpts, nodeID)) require.EqualError(t, ps.eg.Wait(), "failed start session") } @@ -112,17 +116,18 @@ func Test_PostSupervisor_StartsServiceCmd(t *testing.T) { postCfg := DefaultPostConfig() postOpts := DefaultPostSetupOpts() provingOpts := DefaultPostProvingOpts() + nodeID := types.RandomNodeID() ctrl := gomock.NewController(t) mgr := NewMockpostSetupProvider(ctrl) - mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts).Return(nil) - mgr.EXPECT().StartSession(gomock.Any()).Return(nil) + mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts, nodeID).Return(nil) + mgr.EXPECT().StartSession(gomock.Any(), nodeID).Return(nil) ps, err := NewPostSupervisor(log.Named("supervisor"), cmdCfg, postCfg, provingOpts, mgr) require.NoError(t, err) require.NotNil(t, ps) - require.NoError(t, ps.Start(postOpts)) + require.NoError(t, ps.Start(postOpts, nodeID)) t.Cleanup(func() { assert.NoError(t, ps.Stop(false)) }) require.Eventually(t, func() bool { return ps.pid.Load() != 0 }, 
5*time.Second, 100*time.Millisecond) @@ -150,26 +155,27 @@ func Test_PostSupervisor_Restart_Possible(t *testing.T) { postCfg := DefaultPostConfig() postOpts := DefaultPostSetupOpts() provingOpts := DefaultPostProvingOpts() + nodeID := types.RandomNodeID() ctrl := gomock.NewController(t) mgr := NewMockpostSetupProvider(ctrl) - mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts).Return(nil) - mgr.EXPECT().StartSession(gomock.Any()).Return(nil) + mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts, nodeID).Return(nil) + mgr.EXPECT().StartSession(gomock.Any(), nodeID).Return(nil) ps, err := NewPostSupervisor(log.Named("supervisor"), cmdCfg, postCfg, provingOpts, mgr) require.NoError(t, err) require.NotNil(t, ps) - require.NoError(t, ps.Start(postOpts)) + require.NoError(t, ps.Start(postOpts, nodeID)) t.Cleanup(func() { assert.NoError(t, ps.Stop(false)) }) require.Eventually(t, func() bool { return ps.pid.Load() != 0 }, 5*time.Second, 100*time.Millisecond) require.NoError(t, ps.Stop(false)) require.Eventually(t, func() bool { return ps.pid.Load() == 0 }, 5*time.Second, 100*time.Millisecond) - mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts).Return(nil) - mgr.EXPECT().StartSession(gomock.Any()).Return(nil) - require.NoError(t, ps.Start(postOpts)) + mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts, nodeID).Return(nil) + mgr.EXPECT().StartSession(gomock.Any(), nodeID).Return(nil) + require.NoError(t, ps.Start(postOpts, nodeID)) require.Eventually(t, func() bool { return ps.pid.Load() != 0 }, 5*time.Second, 100*time.Millisecond) require.NoError(t, ps.Stop(false)) @@ -183,17 +189,18 @@ func Test_PostSupervisor_LogFatalOnCrash(t *testing.T) { postCfg := DefaultPostConfig() postOpts := DefaultPostSetupOpts() provingOpts := DefaultPostProvingOpts() + nodeID := types.RandomNodeID() ctrl := gomock.NewController(t) mgr := NewMockpostSetupProvider(ctrl) - mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts).Return(nil) - 
mgr.EXPECT().StartSession(gomock.Any()).Return(nil) + mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts, nodeID).Return(nil) + mgr.EXPECT().StartSession(gomock.Any(), nodeID).Return(nil) ps, err := NewPostSupervisor(log.Named("supervisor"), cmdCfg, postCfg, provingOpts, mgr) require.NoError(t, err) require.NotNil(t, ps) - require.NoError(t, ps.Start(postOpts)) + require.NoError(t, ps.Start(postOpts, nodeID)) t.Cleanup(func() { assert.NoError(t, ps.Stop(false)) }) require.Eventually(t, func() bool { return ps.pid.Load() != 0 }, 5*time.Second, 100*time.Millisecond) @@ -217,17 +224,18 @@ func Test_PostSupervisor_LogFatalOnInvalidConfig(t *testing.T) { postCfg := DefaultPostConfig() postOpts := DefaultPostSetupOpts() provingOpts := DefaultPostProvingOpts() + nodeID := types.RandomNodeID() ctrl := gomock.NewController(t) mgr := NewMockpostSetupProvider(ctrl) - mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts).Return(nil) - mgr.EXPECT().StartSession(gomock.Any()).Return(nil) + mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts, nodeID).Return(nil) + mgr.EXPECT().StartSession(gomock.Any(), nodeID).Return(nil) ps, err := NewPostSupervisor(log.Named("supervisor"), cmdCfg, postCfg, provingOpts, mgr) require.NoError(t, err) require.NotNil(t, ps) - require.NoError(t, ps.Start(postOpts)) + require.NoError(t, ps.Start(postOpts, nodeID)) t.Cleanup(func() { assert.NoError(t, ps.Stop(false)) }) require.Eventually(t, func() bool { return ps.pid.Load() != 0 }, 5*time.Second, 100*time.Millisecond) @@ -248,17 +256,18 @@ func Test_PostSupervisor_StopOnError(t *testing.T) { postCfg := DefaultPostConfig() postOpts := DefaultPostSetupOpts() provingOpts := DefaultPostProvingOpts() + nodeID := types.RandomNodeID() ctrl := gomock.NewController(t) mgr := NewMockpostSetupProvider(ctrl) - mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts).Return(nil) - mgr.EXPECT().StartSession(gomock.Any()).Return(nil) + mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts, 
nodeID).Return(nil) + mgr.EXPECT().StartSession(gomock.Any(), nodeID).Return(nil) ps, err := NewPostSupervisor(log.Named("supervisor"), cmdCfg, postCfg, provingOpts, mgr) require.NoError(t, err) require.NotNil(t, ps) - require.NoError(t, ps.Start(postOpts)) + require.NoError(t, ps.Start(postOpts, nodeID)) t.Cleanup(func() { assert.NoError(t, ps.Stop(false)) }) require.Eventually(t, func() bool { return ps.pid.Load() != 0 }, 5*time.Second, 100*time.Millisecond) diff --git a/activation/post_test.go b/activation/post_test.go index 065528228c..be1e85de8a 100644 --- a/activation/post_test.go +++ b/activation/post_test.go @@ -51,21 +51,22 @@ func TestPostSetupManager(t *testing.T) { }) // Create data. - require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts)) - require.NoError(t, mgr.StartSession(context.Background())) + nodeID := types.RandomNodeID() + require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) + require.NoError(t, mgr.StartSession(context.Background(), nodeID)) require.NoError(t, eg.Wait()) require.Equal(t, PostSetupStateComplete, mgr.Status().State) // Create data (same opts). - require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts)) - require.NoError(t, mgr.StartSession(context.Background())) + require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) + require.NoError(t, mgr.StartSession(context.Background(), nodeID)) // Cleanup. require.NoError(t, mgr.Reset()) // Create data (same opts, after deletion). 
- require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts)) - require.NoError(t, mgr.StartSession(context.Background())) + require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) + require.NoError(t, mgr.StartSession(context.Background(), nodeID)) require.Equal(t, PostSetupStateComplete, mgr.Status().State) } @@ -74,60 +75,58 @@ func TestPostSetupManager(t *testing.T) { // and should be fully tested there but we check a few cases to be sure that // PrepareInitializer will return errors when the opts don't validate. func TestPostSetupManager_PrepareInitializer(t *testing.T) { - req := require.New(t) - mgr := newTestPostManager(t) + nodeID := types.RandomNodeID() // check no error with good options. - req.NoError(mgr.PrepareInitializer(context.Background(), mgr.opts)) + require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) defaultConfig := config.DefaultConfig() // Check that invalid options return errors opts := mgr.opts opts.ComputeBatchSize = 3 - req.Error(mgr.PrepareInitializer(context.Background(), opts)) + require.Error(t, mgr.PrepareInitializer(context.Background(), opts, nodeID)) opts = mgr.opts opts.NumUnits = defaultConfig.MaxNumUnits + 1 - req.Error(mgr.PrepareInitializer(context.Background(), opts)) + require.Error(t, mgr.PrepareInitializer(context.Background(), opts, nodeID)) opts = mgr.opts opts.NumUnits = defaultConfig.MinNumUnits - 1 - req.Error(mgr.PrepareInitializer(context.Background(), opts)) + require.Error(t, mgr.PrepareInitializer(context.Background(), opts, nodeID)) opts = mgr.opts opts.Scrypt.N = 0 - req.Error(opts.Scrypt.Validate()) - req.Error(mgr.PrepareInitializer(context.Background(), opts)) + require.Error(t, opts.Scrypt.Validate()) + require.Error(t, mgr.PrepareInitializer(context.Background(), opts, nodeID)) } func TestPostSetupManager_StartSession_WithoutProvider_Error(t *testing.T) { - req := require.New(t) - mgr := newTestPostManager(t) 
mgr.opts.ProviderID.value = nil + nodeID := types.RandomNodeID() + // Create data. - req.NoError(mgr.PrepareInitializer(context.Background(), mgr.opts)) // prepare is fine without provider - req.ErrorContains(mgr.StartSession(context.Background()), "no provider specified") + require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) // prepare is fine without provider + require.ErrorContains(t, mgr.StartSession(context.Background(), nodeID), "no provider specified") - req.Equal(PostSetupStateError, mgr.Status().State) + require.Equal(t, PostSetupStateError, mgr.Status().State) } func TestPostSetupManager_StartSession_WithoutProviderAfterInit_OK(t *testing.T) { - req := require.New(t) - mgr := newTestPostManager(t) + nodeID := types.RandomNodeID() ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() // Create data. - req.NoError(mgr.PrepareInitializer(ctx, mgr.opts)) - req.NoError(mgr.StartSession(ctx)) + require.NoError(t, mgr.PrepareInitializer(ctx, mgr.opts, nodeID)) + require.NoError(t, mgr.StartSession(ctx, nodeID)) - req.Equal(PostSetupStateComplete, mgr.Status().State) + require.Equal(t, PostSetupStateComplete, mgr.Status().State) cancel() // start Initializer again, but with no provider set @@ -136,150 +135,149 @@ func TestPostSetupManager_StartSession_WithoutProviderAfterInit_OK(t *testing.T) ctx, cancel = context.WithTimeout(context.Background(), time.Second*10) defer cancel() - req.NoError(mgr.PrepareInitializer(ctx, mgr.opts)) - req.NoError(mgr.StartSession(ctx)) + require.NoError(t, mgr.PrepareInitializer(ctx, mgr.opts, nodeID)) + require.NoError(t, mgr.StartSession(ctx, nodeID)) - req.Equal(PostSetupStateComplete, mgr.Status().State) + require.Equal(t, PostSetupStateComplete, mgr.Status().State) } // Checks that the sequence of calls for initialization (first // PrepareInitializer and then StartSession) is enforced. 
func TestPostSetupManager_InitializationCallSequence(t *testing.T) { - req := require.New(t) - mgr := newTestPostManager(t) + nodeID := types.RandomNodeID() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() // Should fail since we have not prepared. - req.Error(mgr.StartSession(ctx)) + require.Error(t, mgr.StartSession(ctx, nodeID)) - req.NoError(mgr.PrepareInitializer(ctx, mgr.opts)) + require.NoError(t, mgr.PrepareInitializer(ctx, mgr.opts, nodeID)) // Should fail since we need to call StartSession after PrepareInitializer. - req.Error(mgr.PrepareInitializer(ctx, mgr.opts)) + require.Error(t, mgr.PrepareInitializer(ctx, mgr.opts, nodeID)) - req.NoError(mgr.StartSession(ctx)) + require.NoError(t, mgr.StartSession(ctx, nodeID)) // Should fail since it is required to call PrepareInitializer before each // call to StartSession. - req.Error(mgr.StartSession(ctx)) + require.Error(t, mgr.StartSession(ctx, nodeID)) } func TestPostSetupManager_StateError(t *testing.T) { - req := require.New(t) - mgr := newTestPostManager(t) mgr.opts.NumUnits = 0 - req.Error(mgr.PrepareInitializer(context.Background(), mgr.opts)) + nodeID := types.RandomNodeID() + + require.Error(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) // Verify Status returns StateError - req.Equal(PostSetupStateError, mgr.Status().State) + require.Equal(t, PostSetupStateError, mgr.Status().State) } func TestPostSetupManager_InitialStatus(t *testing.T) { - req := require.New(t) - mgr := newTestPostManager(t) + nodeID := types.RandomNodeID() // Verify the initial status. status := mgr.Status() - req.Equal(PostSetupStateNotStarted, status.State) - req.Zero(status.NumLabelsWritten) + require.Equal(t, PostSetupStateNotStarted, status.State) + require.Zero(t, status.NumLabelsWritten) // Create data. 
- req.NoError(mgr.PrepareInitializer(context.Background(), mgr.opts)) - req.NoError(mgr.StartSession(context.Background())) - req.Equal(PostSetupStateComplete, mgr.Status().State) + require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) + require.NoError(t, mgr.StartSession(context.Background(), nodeID)) + require.Equal(t, PostSetupStateComplete, mgr.Status().State) // Re-instantiate `PostSetupManager`. mgr = newTestPostManager(t) // Verify the initial status. status = mgr.Status() - req.Equal(PostSetupStateNotStarted, status.State) - req.Zero(status.NumLabelsWritten) + require.Equal(t, PostSetupStateNotStarted, status.State) + require.Zero(t, status.NumLabelsWritten) } func TestPostSetupManager_Stop(t *testing.T) { - req := require.New(t) - mgr := newTestPostManager(t) + nodeID := types.RandomNodeID() // Verify state. status := mgr.Status() - req.Equal(PostSetupStateNotStarted, status.State) - req.Zero(status.NumLabelsWritten) + require.Equal(t, PostSetupStateNotStarted, status.State) + require.Zero(t, status.NumLabelsWritten) // Create data. - req.NoError(mgr.PrepareInitializer(context.Background(), mgr.opts)) - req.NoError(mgr.StartSession(context.Background())) + require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) + require.NoError(t, mgr.StartSession(context.Background(), nodeID)) // Verify state. - req.Equal(PostSetupStateComplete, mgr.Status().State) + require.Equal(t, PostSetupStateComplete, mgr.Status().State) // Reset. - req.NoError(mgr.Reset()) + require.NoError(t, mgr.Reset()) // Verify state. - req.Equal(PostSetupStateNotStarted, mgr.Status().State) + require.Equal(t, PostSetupStateNotStarted, mgr.Status().State) // Create data again. 
- req.NoError(mgr.PrepareInitializer(context.Background(), mgr.opts)) - req.NoError(mgr.StartSession(context.Background())) + require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) + require.NoError(t, mgr.StartSession(context.Background(), nodeID)) // Verify state. - req.Equal(PostSetupStateComplete, mgr.Status().State) + require.Equal(t, PostSetupStateComplete, mgr.Status().State) } func TestPostSetupManager_Stop_WhileInProgress(t *testing.T) { - req := require.New(t) - mgr := newTestPostManager(t) mgr.opts.MaxFileSize = 4096 mgr.opts.NumUnits = mgr.cfg.MaxNumUnits + nodeID := types.RandomNodeID() + // Create data. - req.NoError(mgr.PrepareInitializer(context.Background(), mgr.opts)) + require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) ctx, cancel := context.WithCancel(context.Background()) var eg errgroup.Group eg.Go(func() error { - return mgr.StartSession(ctx) + return mgr.StartSession(ctx, nodeID) }) // Verify the intermediate status. - req.Eventually(func() bool { + require.Eventually(t, func() bool { return mgr.Status().State == PostSetupStateInProgress }, 5*time.Second, 10*time.Millisecond) // Stop initialization. cancel() - req.ErrorIs(eg.Wait(), context.Canceled) + require.ErrorIs(t, eg.Wait(), context.Canceled) // Verify status. status := mgr.Status() - req.Equal(PostSetupStateStopped, status.State) - req.LessOrEqual(status.NumLabelsWritten, uint64(mgr.opts.NumUnits)*mgr.cfg.LabelsPerUnit) + require.Equal(t, PostSetupStateStopped, status.State) + require.LessOrEqual(t, status.NumLabelsWritten, uint64(mgr.opts.NumUnits)*mgr.cfg.LabelsPerUnit) // Continue to create data. - req.NoError(mgr.PrepareInitializer(context.Background(), mgr.opts)) - req.NoError(mgr.StartSession(context.Background())) + require.NoError(t, mgr.PrepareInitializer(context.Background(), mgr.opts, nodeID)) + require.NoError(t, mgr.StartSession(context.Background(), nodeID)) // Verify status. 
status = mgr.Status() - req.Equal(PostSetupStateComplete, status.State) - req.Equal(uint64(mgr.opts.NumUnits)*mgr.cfg.LabelsPerUnit, status.NumLabelsWritten) + require.Equal(t, PostSetupStateComplete, status.State) + require.Equal(t, uint64(mgr.opts.NumUnits)*mgr.cfg.LabelsPerUnit, status.NumLabelsWritten) } func TestPostSetupManager_findCommitmentAtx_UsesLatestAtx(t *testing.T) { mgr := newTestPostManager(t) + signer, err := signing.NewEdSigner() + require.NoError(t, err) challenge := types.NIPostChallenge{ PublishEpoch: 1, } atx := types.NewActivationTx(challenge, types.Address{}, nil, 2, nil) - require.NoError(t, SignAndFinalizeAtx(mgr.signer, atx)) + require.NoError(t, SignAndFinalizeAtx(signer, atx)) atx.SetEffectiveNumUnits(atx.NumUnits) atx.SetReceived(time.Now()) vAtx, err := atx.Verify(0, 1) @@ -301,16 +299,17 @@ func TestPostSetupManager_findCommitmentAtx_DefaultsToGoldenAtx(t *testing.T) { func TestPostSetupManager_getCommitmentAtx_getsCommitmentAtxFromPostMetadata(t *testing.T) { mgr := newTestPostManager(t) + nodeID := types.RandomNodeID() // write commitment atx to metadata commitmentAtx := types.RandomATXID() err := initialization.SaveMetadata(mgr.opts.DataDir, &shared.PostMetadata{ CommitmentAtxId: commitmentAtx.Bytes(), - NodeId: mgr.signer.NodeID().Bytes(), + NodeId: nodeID.Bytes(), }) require.NoError(t, err) - atxid, err := mgr.commitmentAtx(context.Background(), mgr.opts.DataDir) + atxid, err := mgr.commitmentAtx(context.Background(), mgr.opts.DataDir, nodeID) require.NoError(t, err) require.NotNil(t, atxid) require.Equal(t, commitmentAtx, atxid) @@ -318,19 +317,21 @@ func TestPostSetupManager_getCommitmentAtx_getsCommitmentAtxFromPostMetadata(t * func TestPostSetupManager_getCommitmentAtx_getsCommitmentAtxFromInitialAtx(t *testing.T) { mgr := newTestPostManager(t) + signer, err := signing.NewEdSigner() + require.NoError(t, err) // add an atx by the same node commitmentAtx := types.RandomATXID() atx := 
types.NewActivationTx(types.NIPostChallenge{}, types.Address{}, nil, 1, nil) atx.CommitmentATX = &commitmentAtx - require.NoError(t, SignAndFinalizeAtx(mgr.signer, atx)) + require.NoError(t, SignAndFinalizeAtx(signer, atx)) atx.SetEffectiveNumUnits(atx.NumUnits) atx.SetReceived(time.Now()) vAtx, err := atx.Verify(0, 1) require.NoError(t, err) require.NoError(t, atxs.Add(mgr.cdb, vAtx)) - atxid, err := mgr.commitmentAtx(context.Background(), mgr.opts.DataDir) + atxid, err := mgr.commitmentAtx(context.Background(), mgr.opts.DataDir, signer.NodeID()) require.NoError(t, err) require.Equal(t, commitmentAtx, atxid) } @@ -340,17 +341,12 @@ type testPostManager struct { opts PostSetupOpts - signer *signing.EdSigner - cdb *datastore.CachedDB + cdb *datastore.CachedDB } func newTestPostManager(tb testing.TB) *testPostManager { tb.Helper() - sig, err := signing.NewEdSigner() - require.NoError(tb, err) - id := sig.NodeID() - opts := DefaultPostSetupOpts() opts.DataDir = tb.TempDir() opts.ProviderID.SetUint32(initialization.CPUProviderID()) @@ -369,13 +365,12 @@ func newTestPostManager(tb testing.TB) *testPostManager { syncer.EXPECT().RegisterForATXSynced().AnyTimes().Return(synced) cdb := datastore.NewCachedDB(sql.InMemory(), logtest.New(tb)) - mgr, err := NewPostSetupManager(id, DefaultPostConfig(), zaptest.NewLogger(tb), cdb, goldenATXID, syncer, validator) + mgr, err := NewPostSetupManager(DefaultPostConfig(), zaptest.NewLogger(tb), cdb, goldenATXID, syncer, validator) require.NoError(tb, err) return &testPostManager{ PostSetupManager: mgr, opts: opts, - signer: sig, cdb: cdb, } } diff --git a/activation/post_verifier.go b/activation/post_verifier.go index fe3ab1a2d0..5443095eb5 100644 --- a/activation/post_verifier.go +++ b/activation/post_verifier.go @@ -112,9 +112,9 @@ func WithVerifyingOpts(opts PostProofVerifyingOpts) PostVerifierOpt { } } -func PrioritizedIDs(ids ...types.NodeID) PostVerifierOpt { +func WithPrioritizedID(id types.NodeID) PostVerifierOpt { return 
func(v *postVerifierOpts) { - v.prioritizedIds = ids + v.prioritizedIds = append(v.prioritizedIds, id) } } diff --git a/activation/validation_test.go b/activation/validation_test.go index 3277d88753..b73be25996 100644 --- a/activation/validation_test.go +++ b/activation/validation_test.go @@ -571,7 +571,6 @@ func TestValidateMerkleProof(t *testing.T) { } func TestVerifyChainDeps(t *testing.T) { - ctrl := gomock.NewController(t) db := sql.InMemory() ctx := context.Background() goldenATXID := types.ATXID{2, 3, 4} @@ -609,6 +608,7 @@ func TestVerifyChainDeps(t *testing.T) { vAtx.SetValidity(types.Unknown) require.NoError(t, atxs.Add(db, vAtx)) + ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) v.EXPECT().Verify(ctx, (*shared.Proof)(atx.NIPost.Post), gomock.Any(), gomock.Any()) @@ -633,6 +633,7 @@ func TestVerifyChainDeps(t *testing.T) { vAtx.SetValidity(types.Unknown) require.NoError(t, atxs.Add(db, vAtx)) + ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) v.EXPECT().Verify(ctx, (*shared.Proof)(atx.NIPost.Post), gomock.Any(), gomock.Any()) @@ -658,6 +659,7 @@ func TestVerifyChainDeps(t *testing.T) { vAtx.SetValidity(types.Unknown) require.NoError(t, atxs.Add(db, vAtx)) + ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) v.EXPECT().Verify(ctx, (*shared.Proof)(atx.NIPost.Post), gomock.Any(), gomock.Any()) validator := NewValidator(db, nil, DefaultPostConfig(), config.ScryptParams{}, v) @@ -681,6 +683,7 @@ func TestVerifyChainDeps(t *testing.T) { vAtx.SetValidity(types.Unknown) require.NoError(t, atxs.Add(db, vAtx)) + ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) validator := NewValidator(db, nil, DefaultPostConfig(), config.ScryptParams{}, v) err = validator.VerifyChain(ctx, vAtx.ID(), goldenATXID, VerifyChainOpts.WithTrustedID(signer.NodeID())) @@ -703,9 +706,11 @@ func TestVerifyChainDeps(t *testing.T) { vAtx.SetValidity(types.Unknown) require.NoError(t, atxs.Add(db, vAtx)) + ctrl := gomock.NewController(t) v 
:= NewMockPostVerifier(ctrl) validator := NewValidator(db, nil, DefaultPostConfig(), config.ScryptParams{}, v) - err = validator.VerifyChain(ctx, vAtx.ID(), goldenATXID, VerifyChainOpts.AssumeValidBefore(time.Now())) + before := time.Now().Add(10 * time.Second) + err = validator.VerifyChain(ctx, vAtx.ID(), goldenATXID, VerifyChainOpts.AssumeValidBefore(before)) require.NoError(t, err) }) @@ -725,8 +730,9 @@ func TestVerifyChainDeps(t *testing.T) { vAtx.SetValidity(types.Unknown) require.NoError(t, atxs.Add(db, vAtx)) - expected := errors.New("post is invalid") + ctrl := gomock.NewController(t) v := NewMockPostVerifier(ctrl) + expected := errors.New("post is invalid") v.EXPECT().Verify(ctx, (*shared.Proof)(atx.NIPost.Post), gomock.Any(), gomock.Any()).Return(expected) validator := NewValidator(db, nil, DefaultPostConfig(), config.ScryptParams{}, v) err = validator.VerifyChain(ctx, vAtx.ID(), goldenATXID) diff --git a/api/grpcserver/grpcserver_test.go b/api/grpcserver/grpcserver_test.go index 071ed17d9e..b2f537b03c 100644 --- a/api/grpcserver/grpcserver_test.go +++ b/api/grpcserver/grpcserver_test.go @@ -554,11 +554,17 @@ type smesherServiceConn struct { postSupervisor *MockpostSupervisor } -func setupSmesherService(t *testing.T) (*smesherServiceConn, context.Context) { +func setupSmesherService(t *testing.T, id *types.NodeID) (*smesherServiceConn, context.Context) { ctrl, mockCtx := gomock.WithContext(context.Background(), t) smeshingProvider := activation.NewMockSmeshingProvider(ctrl) postSupervisor := NewMockpostSupervisor(ctrl) - svc := NewSmesherService(smeshingProvider, postSupervisor, 10*time.Millisecond, activation.DefaultPostSetupOpts()) + svc := NewSmesherService( + smeshingProvider, + postSupervisor, + 10*time.Millisecond, + id, + activation.DefaultPostSetupOpts(), + ) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -578,7 +584,7 @@ func setupSmesherService(t *testing.T) (*smesherServiceConn, context.Context) { func TestSmesherService(t 
*testing.T) { t.Run("IsSmeshing", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) c.smeshingProvider.EXPECT().Smeshing().Return(false) res, err := c.IsSmeshing(ctx, &emptypb.Empty{}) require.NoError(t, err) @@ -587,7 +593,7 @@ func TestSmesherService(t *testing.T) { t.Run("StartSmeshingMissingArgs", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) _, err := c.StartSmeshing(ctx, &pb.StartSmeshingRequest{}) require.Equal(t, codes.InvalidArgument, status.Code(err)) }) @@ -600,8 +606,9 @@ func TestSmesherService(t *testing.T) { opts.MaxFileSize = 1024 coinbase := &pb.AccountId{Address: addr1.String()} + nodeID := types.RandomNodeID() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, &nodeID) c.smeshingProvider.EXPECT().StartSmeshing(gomock.Any()).Return(nil) c.postSupervisor.EXPECT().Start(gomock.All( gomock.Cond(func(postOpts any) bool { return postOpts.(activation.PostSetupOpts).DataDir == opts.DataDir }), @@ -611,7 +618,7 @@ func TestSmesherService(t *testing.T) { gomock.Cond( func(postOpts any) bool { return postOpts.(activation.PostSetupOpts).MaxFileSize == opts.MaxFileSize }, ), - )).Return(nil) + ), nodeID).Return(nil) res, err := c.StartSmeshing(ctx, &pb.StartSmeshingRequest{ Opts: opts, Coinbase: coinbase, @@ -620,9 +627,28 @@ func TestSmesherService(t *testing.T) { require.Equal(t, int32(code.Code_OK), res.Status.Code) }) + t.Run("StartSmeshingMultiSetup", func(t *testing.T) { + t.Parallel() + opts := &pb.PostSetupOpts{} + opts.DataDir = t.TempDir() + opts.NumUnits = 1 + opts.MaxFileSize = 1024 + + coinbase := &pb.AccountId{Address: addr1.String()} + + c, ctx := setupSmesherService(t, nil) // in multi smeshing setup the node id is nil and start smeshing should fail + res, err := c.StartSmeshing(ctx, &pb.StartSmeshingRequest{ + Opts: opts, + Coinbase: coinbase, + }) + require.Equal(t, codes.FailedPrecondition, 
status.Code(err)) + require.ErrorContains(t, err, "node is not configured for supervised smeshing") + require.Nil(t, res) + }) + t.Run("StopSmeshing", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) c.smeshingProvider.EXPECT().StopSmeshing(gomock.Any()).Return(nil) c.postSupervisor.EXPECT().Stop(false).Return(nil) res, err := c.StopSmeshing(ctx, &pb.StopSmeshingRequest{}) @@ -632,7 +658,7 @@ func TestSmesherService(t *testing.T) { t.Run("SmesherIDs", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) nodeId := types.RandomNodeID() c.smeshingProvider.EXPECT().SmesherIDs().Return([]types.NodeID{nodeId}) res, err := c.SmesherIDs(ctx, &emptypb.Empty{}) @@ -643,7 +669,7 @@ func TestSmesherService(t *testing.T) { t.Run("SetCoinbaseMissingArgs", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) _, err := c.SetCoinbase(ctx, &pb.SetCoinbaseRequest{}) require.Error(t, err) statusCode := status.Code(err) @@ -652,7 +678,7 @@ func TestSmesherService(t *testing.T) { t.Run("SetCoinbase", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) c.smeshingProvider.EXPECT().SetCoinbase(addr1) res, err := c.SetCoinbase(ctx, &pb.SetCoinbaseRequest{ Id: &pb.AccountId{Address: addr1.String()}, @@ -663,7 +689,7 @@ func TestSmesherService(t *testing.T) { t.Run("Coinbase", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) c.smeshingProvider.EXPECT().Coinbase().Return(addr1) res, err := c.Coinbase(ctx, &emptypb.Empty{}) require.NoError(t, err) @@ -674,7 +700,7 @@ func TestSmesherService(t *testing.T) { t.Run("MinGas", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) _, err := c.MinGas(ctx, &emptypb.Empty{}) require.Error(t, err) statusCode := 
status.Code(err) @@ -683,7 +709,7 @@ func TestSmesherService(t *testing.T) { t.Run("SetMinGas", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) _, err := c.SetMinGas(ctx, &pb.SetMinGasRequest{}) require.Error(t, err) statusCode := status.Code(err) @@ -692,7 +718,7 @@ func TestSmesherService(t *testing.T) { t.Run("PostSetupComputeProviders", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) c.postSupervisor.EXPECT().Providers().Return(nil, nil) _, err := c.PostSetupProviders(ctx, &pb.PostSetupProvidersRequest{Benchmark: false}) require.NoError(t, err) @@ -700,7 +726,7 @@ func TestSmesherService(t *testing.T) { t.Run("PostSetupStatusStream", func(t *testing.T) { t.Parallel() - c, ctx := setupSmesherService(t) + c, ctx := setupSmesherService(t, nil) c.postSupervisor.EXPECT().Status().Return(&activation.PostSetupStatus{}).AnyTimes() ctx, cancel := context.WithCancel(ctx) diff --git a/api/grpcserver/interface.go b/api/grpcserver/interface.go index ebc41f7022..ce3c3f6fb9 100644 --- a/api/grpcserver/interface.go +++ b/api/grpcserver/interface.go @@ -57,7 +57,7 @@ type atxProvider interface { } type postSupervisor interface { - Start(opts activation.PostSetupOpts) error + Start(opts activation.PostSetupOpts, id types.NodeID) error Stop(deleteFiles bool) error Config() activation.PostConfig diff --git a/api/grpcserver/mocks.go b/api/grpcserver/mocks.go index 40f4256cd7..01ce9bbdcf 100644 --- a/api/grpcserver/mocks.go +++ b/api/grpcserver/mocks.go @@ -1089,17 +1089,17 @@ func (c *MockpostSupervisorProvidersCall) DoAndReturn(f func() ([]activation.Pos } // Start mocks base method. 
-func (m *MockpostSupervisor) Start(opts activation.PostSetupOpts) error { +func (m *MockpostSupervisor) Start(opts activation.PostSetupOpts, id types.NodeID) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Start", opts) + ret := m.ctrl.Call(m, "Start", opts, id) ret0, _ := ret[0].(error) return ret0 } // Start indicates an expected call of Start. -func (mr *MockpostSupervisorMockRecorder) Start(opts any) *MockpostSupervisorStartCall { +func (mr *MockpostSupervisorMockRecorder) Start(opts, id any) *MockpostSupervisorStartCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockpostSupervisor)(nil).Start), opts) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockpostSupervisor)(nil).Start), opts, id) return &MockpostSupervisorStartCall{Call: call} } @@ -1115,13 +1115,13 @@ func (c *MockpostSupervisorStartCall) Return(arg0 error) *MockpostSupervisorStar } // Do rewrite *gomock.Call.Do -func (c *MockpostSupervisorStartCall) Do(f func(activation.PostSetupOpts) error) *MockpostSupervisorStartCall { +func (c *MockpostSupervisorStartCall) Do(f func(activation.PostSetupOpts, types.NodeID) error) *MockpostSupervisorStartCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockpostSupervisorStartCall) DoAndReturn(f func(activation.PostSetupOpts) error) *MockpostSupervisorStartCall { +func (c *MockpostSupervisorStartCall) DoAndReturn(f func(activation.PostSetupOpts, types.NodeID) error) *MockpostSupervisorStartCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/api/grpcserver/post_service_test.go b/api/grpcserver/post_service_test.go index dae89fda01..8697f35150 100644 --- a/api/grpcserver/post_service_test.go +++ b/api/grpcserver/post_service_test.go @@ -40,7 +40,6 @@ func launchPostSupervisor( sig, err := signing.NewEdSigner() require.NoError(tb, err) - id := sig.NodeID() goldenATXID := types.RandomATXID() validator := 
activation.NewMocknipostValidator(gomock.NewController(tb)) @@ -55,15 +54,15 @@ func launchPostSupervisor( return ch }) cdb := datastore.NewCachedDB(sql.InMemory(), logtest.New(tb)) - mgr, err := activation.NewPostSetupManager(id, postCfg, log.Named("post manager"), cdb, goldenATXID, syncer, validator) + mgr, err := activation.NewPostSetupManager(postCfg, log.Named("post manager"), cdb, goldenATXID, syncer, validator) require.NoError(tb, err) // start post supervisor ps, err := activation.NewPostSupervisor(log, serviceCfg, postCfg, provingOpts, mgr) require.NoError(tb, err) require.NotNil(tb, ps) - require.NoError(tb, ps.Start(postOpts)) - return id, func() { assert.NoError(tb, ps.Stop(false)) } + require.NoError(tb, ps.Start(postOpts, sig.NodeID())) + return sig.NodeID(), func() { assert.NoError(tb, ps.Stop(false)) } } func launchPostSupervisorTLS( @@ -84,7 +83,6 @@ func launchPostSupervisorTLS( sig, err := signing.NewEdSigner() require.NoError(tb, err) - id := sig.NodeID() goldenATXID := types.RandomATXID() validator := activation.NewMocknipostValidator(gomock.NewController(tb)) @@ -98,14 +96,14 @@ func launchPostSupervisorTLS( return ch }) cdb := datastore.NewCachedDB(sql.InMemory(), logtest.New(tb)) - mgr, err := activation.NewPostSetupManager(id, postCfg, log.Named("post manager"), cdb, goldenATXID, syncer, validator) + mgr, err := activation.NewPostSetupManager(postCfg, log.Named("post manager"), cdb, goldenATXID, syncer, validator) require.NoError(tb, err) ps, err := activation.NewPostSupervisor(log, serviceCfg, postCfg, provingOpts, mgr) require.NoError(tb, err) require.NotNil(tb, ps) - require.NoError(tb, ps.Start(postOpts)) - return id, func() { assert.NoError(tb, ps.Stop(false)) } + require.NoError(tb, ps.Start(postOpts, sig.NodeID())) + return sig.NodeID(), func() { assert.NoError(tb, ps.Stop(false)) } } func Test_GenerateProof(t *testing.T) { diff --git a/api/grpcserver/smesher_service.go b/api/grpcserver/smesher_service.go index e6df58f953..86d91d2086 
100644 --- a/api/grpcserver/smesher_service.go +++ b/api/grpcserver/smesher_service.go @@ -27,6 +27,7 @@ type SmesherService struct { postSupervisor postSupervisor streamInterval time.Duration + nodeID *types.NodeID postOpts activation.PostSetupOpts } @@ -49,12 +50,14 @@ func NewSmesherService( smeshing activation.SmeshingProvider, postSupervisor postSupervisor, streamInterval time.Duration, + nodeID *types.NodeID, postOpts activation.PostSetupOpts, ) *SmesherService { return &SmesherService{ smeshingProvider: smeshing, postSupervisor: postSupervisor, streamInterval: streamInterval, + nodeID: nodeID, postOpts: postOpts, } } @@ -81,7 +84,10 @@ func (s SmesherService) StartSmeshing( if err != nil { status.Error(codes.InvalidArgument, err.Error()) } - if err := s.postSupervisor.Start(opts); err != nil { + if s.nodeID == nil { + return nil, status.Errorf(codes.FailedPrecondition, "node is not configured for supervised smeshing") + } + if err := s.postSupervisor.Start(opts, *s.nodeID); err != nil { ctxzap.Error(ctx, "failed to start post supervisor", zap.Error(err)) return nil, status.Error(codes.Internal, fmt.Sprintf("failed to start post supervisor: %v", err)) } diff --git a/api/grpcserver/smesher_service_test.go b/api/grpcserver/smesher_service_test.go index 9c8b4a8ff9..b2acbe8c81 100644 --- a/api/grpcserver/smesher_service_test.go +++ b/api/grpcserver/smesher_service_test.go @@ -27,6 +27,7 @@ func TestPostConfig(t *testing.T) { smeshingProvider, postSupervisor, time.Second, + nil, activation.DefaultPostSetupOpts(), ) @@ -53,10 +54,12 @@ func TestStartSmeshingPassesCorrectSmeshingOpts(t *testing.T) { ctrl := gomock.NewController(t) smeshingProvider := activation.NewMockSmeshingProvider(ctrl) postSupervisor := grpcserver.NewMockpostSupervisor(ctrl) + nodeID := types.RandomNodeID() svc := grpcserver.NewSmesherService( smeshingProvider, postSupervisor, time.Second, + &nodeID, activation.DefaultPostSetupOpts(), ) @@ -73,7 +76,7 @@ func 
TestStartSmeshingPassesCorrectSmeshingOpts(t *testing.T) { ComputeBatchSize: config.DefaultComputeBatchSize, } opts.ProviderID.SetUint32(providerID) - postSupervisor.EXPECT().Start(opts).Return(nil) + postSupervisor.EXPECT().Start(opts, nodeID).Return(nil) smeshingProvider.EXPECT().StartSmeshing(addr).Return(nil) _, err = svc.StartSmeshing(context.Background(), &pb.StartSmeshingRequest{ @@ -89,6 +92,44 @@ func TestStartSmeshingPassesCorrectSmeshingOpts(t *testing.T) { require.NoError(t, err) } +func TestStartSmeshing_ErrorOnMultiSmeshingSetup(t *testing.T) { + ctrl := gomock.NewController(t) + smeshingProvider := activation.NewMockSmeshingProvider(ctrl) + postSupervisor := grpcserver.NewMockpostSupervisor(ctrl) + svc := grpcserver.NewSmesherService( + smeshingProvider, + postSupervisor, + time.Second, + nil, // no nodeID in multi smesher setup + activation.DefaultPostSetupOpts(), + ) + + types.SetNetworkHRP("stest") + providerID := uint32(7) + opts := activation.PostSetupOpts{ + DataDir: "data-dir", + NumUnits: 1, + MaxFileSize: 1024, + Throttle: true, + Scrypt: config.DefaultLabelParams(), + ComputeBatchSize: config.DefaultComputeBatchSize, + } + opts.ProviderID.SetUint32(providerID) + + _, err := svc.StartSmeshing(context.Background(), &pb.StartSmeshingRequest{ + Coinbase: &pb.AccountId{Address: "stest1qqqqqqrs60l66w5uksxzmaznwq6xnhqfv56c28qlkm4a5"}, + Opts: &pb.PostSetupOpts{ + DataDir: "data-dir", + NumUnits: 1, + MaxFileSize: 1024, + ProviderId: &providerID, + Throttle: true, + }, + }) + require.Equal(t, codes.FailedPrecondition, status.Code(err)) + require.ErrorContains(t, err, "node is not configured for supervised smeshing") +} + func TestSmesherService_PostSetupProviders(t *testing.T) { ctrl := gomock.NewController(t) smeshingProvider := activation.NewMockSmeshingProvider(ctrl) @@ -97,6 +138,7 @@ func TestSmesherService_PostSetupProviders(t *testing.T) { smeshingProvider, postSupervisor, time.Second, + nil, activation.DefaultPostSetupOpts(), ) @@ -143,6 
+185,7 @@ func TestSmesherService_PostSetupStatus(t *testing.T) { smeshingProvider, postSupervisor, time.Second, + nil, activation.DefaultPostSetupOpts(), ) @@ -166,6 +209,7 @@ func TestSmesherService_PostSetupStatus(t *testing.T) { smeshingProvider, postSupervisor, time.Second, + nil, activation.DefaultPostSetupOpts(), ) @@ -202,6 +246,7 @@ func TestSmesherService_PostSetupStatus(t *testing.T) { smeshingProvider, postSupervisor, time.Second, + nil, activation.DefaultPostSetupOpts(), ) @@ -239,6 +284,7 @@ func TestSmesherService_SmesherID(t *testing.T) { smeshingProvider, postSupervisor, time.Second, + nil, activation.DefaultPostSetupOpts(), ) diff --git a/beacon/beacon.go b/beacon/beacon.go index 43df444621..a29f412c11 100644 --- a/beacon/beacon.go +++ b/beacon/beacon.go @@ -135,16 +135,16 @@ func New( return pd } -func (pd *ProtocolDriver) Register(s *signing.EdSigner) { +func (pd *ProtocolDriver) Register(sig *signing.EdSigner) { pd.mu.Lock() defer pd.mu.Unlock() - if _, exists := pd.signers[s.NodeID()]; exists { - pd.logger.With().Error("signing key already registered", log.ShortStringer("id", s.NodeID())) + if _, exists := pd.signers[sig.NodeID()]; exists { + pd.logger.With().Error("signing key already registered", log.ShortStringer("id", sig.NodeID())) return } - pd.logger.With().Info("registered signing key", log.ShortStringer("id", s.NodeID())) - pd.signers[s.NodeID()] = s + pd.logger.With().Info("registered signing key", log.ShortStringer("id", sig.NodeID())) + pd.signers[sig.NodeID()] = sig } type participant struct { diff --git a/blocks/certifier.go b/blocks/certifier.go index fbc95ef324..ea6689c0d6 100644 --- a/blocks/certifier.go +++ b/blocks/certifier.go @@ -130,16 +130,16 @@ func NewCertifier( return c } -func (c *Certifier) Register(s *signing.EdSigner) { +func (c *Certifier) Register(sig *signing.EdSigner) { c.mu.Lock() defer c.mu.Unlock() - if _, exists := c.signers[s.NodeID()]; exists { - c.logger.With().Error("signing key already registered", 
log.ShortStringer("id", s.NodeID())) + if _, exists := c.signers[sig.NodeID()]; exists { + c.logger.With().Error("signing key already registered", log.ShortStringer("id", sig.NodeID())) return } - c.logger.With().Info("registered signing key", log.ShortStringer("id", s.NodeID())) - c.signers[s.NodeID()] = s + c.logger.With().Info("registered signing key", log.ShortStringer("id", sig.NodeID())) + c.signers[sig.NodeID()] = sig } // Start starts the background goroutine for periodic pruning. diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 3d4ee5e13d..63d2bcb292 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -1,14 +1,17 @@ package checkpoint import ( + "bytes" "context" "encoding/json" "errors" "fmt" "net/url" "path/filepath" + "slices" "github.com/spf13/afero" + "golang.org/x/exp/maps" "github.com/spacemeshos/go-spacemesh/bootstrap" "github.com/spacemeshos/go-spacemesh/codec" @@ -45,7 +48,7 @@ type RecoverConfig struct { DbFile string LocalDbFile string PreserveOwnAtx bool - NodeID types.NodeID + NodeIDs []types.NodeID Uri string Restore types.LayerID } @@ -145,14 +148,14 @@ func RecoverWithDb( return nil, fmt.Errorf("remove old bootstrap data: %w", err) } logger.With().Info("recover from uri", log.String("uri", cfg.Uri)) - cpfile, err := copyToLocalFile(ctx, logger, fs, cfg.DataDir, cfg.Uri, cfg.Restore) + cpFile, err := copyToLocalFile(ctx, logger, fs, cfg.DataDir, cfg.Uri, cfg.Restore) if err != nil { return nil, err } - return recoverFromLocalFile(ctx, logger, db, localDB, fs, cfg, cpfile) + return recoverFromLocalFile(ctx, logger, db, localDB, fs, cfg, cpFile) } -type recoverydata struct { +type recoveryData struct { accounts []*types.Account atxs []*atxs.CheckpointAtx } @@ -176,19 +179,47 @@ func recoverFromLocalFile( log.Int("num accounts", len(data.accounts)), log.Int("num atxs", len(data.atxs)), ) - deps, proofs, err := collectOwnAtxDeps(logger, db, localDB, cfg, data) - if err != nil { - logger.With().Error("failed to 
collect deps for own atx", log.Err(err)) - // continue to recover from checkpoint despite failure to preserve own atx - } else if len(deps) > 0 { - logger.With().Info("collected own atx deps", - log.Context(ctx), - log.Int("own atx deps", len(deps)), - ) + deps := make(map[types.ATXID]*types.VerifiedActivationTx) + proofs := make(map[types.PoetProofRef]*types.PoetProofMessage) + if cfg.PreserveOwnAtx { + for _, nodeID := range cfg.NodeIDs { + nodeDeps, nodeProofs, err := collectOwnAtxDeps(logger, db, localDB, nodeID, cfg.GoldenAtx, data) + if err != nil { + logger.With().Error("failed to collect deps for own atx", + nodeID, + log.Err(err), + ) + // continue to recover from checkpoint despite failure to preserve own atx + continue + } + + logger.With().Info("collected own atx deps", + log.Context(ctx), + nodeID, + log.Int("own atx deps", len(nodeDeps)), + ) + maps.Copy(deps, nodeDeps) + maps.Copy(proofs, nodeProofs) + } } if err := db.Close(); err != nil { return nil, fmt.Errorf("close old db: %w", err) } + allDeps := maps.Values(deps) + // sort ATXs them by publishEpoch and then by ID + slices.SortFunc(allDeps, func(i, j *types.VerifiedActivationTx) int { + return bytes.Compare(i.ID().Bytes(), j.ID().Bytes()) + }) + slices.SortStableFunc(allDeps, func(i, j *types.VerifiedActivationTx) int { + return int(i.PublishEpoch) - int(j.PublishEpoch) + }) + allProofs := maps.Values(proofs) + // sort PoET proofs by ref + slices.SortFunc(allProofs, func(i, j *types.PoetProofMessage) int { + iRef, _ := i.Ref() + jRef, _ := j.Ref() + return bytes.Compare(iRef[:], jRef[:]) + }) // all is ready. backup the old data and create new. 
backupDir, err := backupOldDb(fs, cfg.DataDir, cfg.DbFile) @@ -200,17 +231,17 @@ func recoverFromLocalFile( log.String("backup dir", backupDir), ) - newdb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) if err != nil { return nil, fmt.Errorf("open sqlite db %w", err) } - defer newdb.Close() + defer newDB.Close() logger.With().Info("populating new database", log.Context(ctx), log.Int("num accounts", len(data.accounts)), log.Int("num atxs", len(data.atxs)), ) - if err = newdb.WithTx(ctx, func(tx *sql.Tx) error { + if err = newDB.WithTx(ctx, func(tx *sql.Tx) error { for _, acct := range data.accounts { if err = accounts.Update(tx, acct); err != nil { return fmt.Errorf("restore account snapshot: %w", err) @@ -222,18 +253,18 @@ func recoverFromLocalFile( log.Uint64("balance", acct.Balance), ) } - for _, catx := range data.atxs { - if err = atxs.AddCheckpointed(tx, catx); err != nil { - return fmt.Errorf("add checkpoint atx %s: %w", catx.ID.String(), err) + for _, cAtx := range data.atxs { + if err = atxs.AddCheckpointed(tx, cAtx); err != nil { + return fmt.Errorf("add checkpoint atx %s: %w", cAtx.ID.String(), err) } logger.With().Info("checkpoint atx saved", log.Context(ctx), - catx.ID, - catx.SmesherID, + cAtx.ID, + cAtx.SmesherID, ) } if err = recovery.SetCheckpoint(tx, cfg.Restore); err != nil { - return fmt.Errorf("save checkppoint info: %w", err) + return fmt.Errorf("save checkpoint info: %w", err) } return nil }); err != nil { @@ -248,18 +279,18 @@ func recoverFromLocalFile( types.GetEffectiveGenesis(), ) var preserve *PreservedData - if len(deps) > 0 { - preserve = &PreservedData{Deps: deps, Proofs: proofs} + if len(allDeps) > 0 { + preserve = &PreservedData{Deps: allDeps, Proofs: allProofs} } return preserve, nil } -func checkpointData(fs afero.Fs, file string, newGenesis types.LayerID) (*recoverydata, error) { +func checkpointData(fs afero.Fs, file string, newGenesis 
types.LayerID) (*recoveryData, error) { data, err := afero.ReadFile(fs, file) if err != nil { return nil, fmt.Errorf("%w: read recovery file %v", err, file) } - if err = ValidateSchema(data); err != nil { + if err := ValidateSchema(data); err != nil { return nil, err } var checkpoint types.Checkpoint @@ -288,20 +319,20 @@ func checkpointData(fs afero.Fs, file string, newGenesis types.LayerID) (*recove } allAtxs := make([]*atxs.CheckpointAtx, 0, len(checkpoint.Data.Atxs)) for _, atx := range checkpoint.Data.Atxs { - var catx atxs.CheckpointAtx - catx.ID = types.ATXID(types.BytesToHash(atx.ID)) - catx.Epoch = types.EpochID(atx.Epoch) - catx.CommitmentATX = types.ATXID(types.BytesToHash(atx.CommitmentAtx)) - catx.SmesherID = types.BytesToNodeID(atx.PublicKey) - catx.NumUnits = atx.NumUnits - catx.VRFNonce = types.VRFPostIndex(atx.VrfNonce) - catx.BaseTickHeight = atx.BaseTickHeight - catx.TickCount = atx.TickCount - catx.Sequence = atx.Sequence - copy(catx.Coinbase[:], atx.Coinbase) - allAtxs = append(allAtxs, &catx) - } - return &recoverydata{ + var cAtx atxs.CheckpointAtx + cAtx.ID = types.ATXID(types.BytesToHash(atx.ID)) + cAtx.Epoch = types.EpochID(atx.Epoch) + cAtx.CommitmentATX = types.ATXID(types.BytesToHash(atx.CommitmentAtx)) + cAtx.SmesherID = types.BytesToNodeID(atx.PublicKey) + cAtx.NumUnits = atx.NumUnits + cAtx.VRFNonce = types.VRFPostIndex(atx.VrfNonce) + cAtx.BaseTickHeight = atx.BaseTickHeight + cAtx.TickCount = atx.TickCount + cAtx.Sequence = atx.Sequence + copy(cAtx.Coinbase[:], atx.Coinbase) + allAtxs = append(allAtxs, &cAtx) + } + return &recoveryData{ accounts: allAccts, atxs: allAtxs, }, nil @@ -311,13 +342,11 @@ func collectOwnAtxDeps( logger log.Log, db *sql.Database, localDB *localsql.Database, - cfg *RecoverConfig, - data *recoverydata, -) ([]*types.VerifiedActivationTx, []*types.PoetProofMessage, error) { - if !cfg.PreserveOwnAtx { - return nil, nil, nil - } - atxid, err := atxs.GetLastIDByNodeID(db, cfg.NodeID) + nodeID types.NodeID, + 
goldenATX types.ATXID, + data *recoveryData, +) (map[types.ATXID]*types.VerifiedActivationTx, map[types.PoetProofRef]*types.PoetProofMessage, error) { + atxid, err := atxs.GetLastIDByNodeID(db, nodeID) if err != nil && !errors.Is(err, sql.ErrNotFound) { return nil, nil, fmt.Errorf("query own last atx id: %w", err) } @@ -329,8 +358,8 @@ func collectOwnAtxDeps( own = true } - // check for if miner is building any atx - nipostCh, _ := nipost.Challenge(localDB, cfg.NodeID) + // check for if smesher is building any atx + nipostCh, _ := nipost.Challenge(localDB, nodeID) if ref == types.EmptyATXID { if nipostCh == nil { return nil, nil, nil @@ -340,36 +369,34 @@ func collectOwnAtxDeps( } } - all := map[types.ATXID]struct{}{cfg.GoldenAtx: {}, types.EmptyATXID: {}} - for _, catx := range data.atxs { - all[catx.ID] = struct{}{} + all := map[types.ATXID]struct{}{goldenATX: {}, types.EmptyATXID: {}} + for _, cAtx := range data.atxs { + all[cAtx.ID] = struct{}{} } var ( - deps []*types.VerifiedActivationTx - proofs []*types.PoetProofMessage + deps map[types.ATXID]*types.VerifiedActivationTx + proofs map[types.PoetProofRef]*types.PoetProofMessage ) if ref != types.EmptyATXID { logger.With().Info("collecting atx and deps", ref, log.Bool("own", own), ) - deps, proofs, err = collectDeps(db, cfg.GoldenAtx, ref, all) + deps, proofs, err = collectDeps(db, goldenATX, ref, all) if err != nil { return nil, nil, err } } if nipostCh != nil { - logger.With().Info("collecting pending atx and deps", - log.Object("nipost", nipostCh), - ) + logger.With().Info("collecting pending atx and deps", log.Object("nipost", nipostCh)) // any previous atx in nipost should already be captured earlier // we only care about positioning atx here - deps2, proofs2, err := collectDeps(db, cfg.GoldenAtx, nipostCh.PositioningATX, all) + deps2, proofs2, err := collectDeps(db, goldenATX, nipostCh.PositioningATX, all) if err != nil { return nil, nil, fmt.Errorf("deps from nipost positioning atx (%v): %w", 
nipostCh.PositioningATX, err) } - deps = append(deps, deps2...) - proofs = append(proofs, proofs2...) + maps.Copy(deps, deps2) + maps.Copy(proofs, proofs2) } return deps, proofs, nil } @@ -379,9 +406,9 @@ func collectDeps( goldenAtxId types.ATXID, ref types.ATXID, all map[types.ATXID]struct{}, -) ([]*types.VerifiedActivationTx, []*types.PoetProofMessage, error) { - var deps []*types.VerifiedActivationTx - if err := collect(db, goldenAtxId, ref, all, &deps); err != nil { +) (map[types.ATXID]*types.VerifiedActivationTx, map[types.PoetProofRef]*types.PoetProofMessage, error) { + deps := make(map[types.ATXID]*types.VerifiedActivationTx) + if err := collect(db, goldenAtxId, ref, all, deps); err != nil { return nil, nil, err } proofs, err := poetProofs(db, deps) @@ -396,7 +423,7 @@ func collect( goldenAtxID types.ATXID, ref types.ATXID, all map[types.ATXID]struct{}, - deps *[]*types.VerifiedActivationTx, + deps map[types.ATXID]*types.VerifiedActivationTx, ) error { if _, ok := all[ref]; ok { return nil @@ -427,14 +454,17 @@ func collect( if err = collect(db, goldenAtxID, atx.PositioningATX, all, deps); err != nil { return err } - *deps = append(*deps, atx) + deps[ref] = atx all[ref] = struct{}{} return nil } -func poetProofs(db *sql.Database, vatxs []*types.VerifiedActivationTx) ([]*types.PoetProofMessage, error) { - var proofs []*types.PoetProofMessage - for _, vatx := range vatxs { +func poetProofs( + db *sql.Database, + vAtxs map[types.ATXID]*types.VerifiedActivationTx, +) (map[types.PoetProofRef]*types.PoetProofMessage, error) { + proofs := make(map[types.PoetProofRef]*types.PoetProofMessage, len(vAtxs)) + for _, vatx := range vAtxs { proof, err := poets.Get(db, types.PoetProofRef(vatx.GetPoetProofRef())) if err != nil { return nil, fmt.Errorf("get poet proof (%v): %w", vatx.ID(), err) @@ -443,7 +473,11 @@ func poetProofs(db *sql.Database, vatxs []*types.VerifiedActivationTx) ([]*types if err := codec.Decode(proof, &msg); err != nil { return nil, fmt.Errorf("decode 
poet proof (%v): %w", vatx.ID(), err) } - proofs = append(proofs, &msg) + ref, err := msg.Ref() + if err != nil { + return nil, fmt.Errorf("get poet proof ref (%v): %w", vatx.ID(), err) + } + proofs[ref] = &msg } return proofs, nil } diff --git a/checkpoint/recovery_test.go b/checkpoint/recovery_test.go index 2871c41f09..7f5329df3c 100644 --- a/checkpoint/recovery_test.go +++ b/checkpoint/recovery_test.go @@ -14,6 +14,7 @@ import ( "github.com/spacemeshos/poet/shared" "github.com/spf13/afero" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -40,27 +41,27 @@ const recoverLayer uint32 = 18 var goldenAtx = types.ATXID{1} -func atxequal( +func atxEqual( tb testing.TB, - satx types.AtxSnapshot, - vatx *types.VerifiedActivationTx, + sAtx types.AtxSnapshot, + vAtx *types.VerifiedActivationTx, commitAtx types.ATXID, vrfnonce types.VRFPostIndex, ) { - require.True(tb, bytes.Equal(satx.ID, vatx.ID().Bytes())) - require.EqualValues(tb, satx.Epoch, vatx.PublishEpoch) - require.True(tb, bytes.Equal(satx.CommitmentAtx, commitAtx.Bytes())) - require.EqualValues(tb, satx.VrfNonce, vrfnonce) - require.Equal(tb, satx.NumUnits, vatx.NumUnits) - require.Equal(tb, satx.BaseTickHeight, vatx.BaseTickHeight()) - require.Equal(tb, satx.TickCount, vatx.TickCount()) - require.True(tb, bytes.Equal(satx.PublicKey, vatx.SmesherID.Bytes())) - require.Equal(tb, satx.Sequence, vatx.Sequence) - require.True(tb, bytes.Equal(satx.Coinbase, vatx.Coinbase.Bytes())) - require.True(tb, vatx.Golden()) + require.True(tb, bytes.Equal(sAtx.ID, vAtx.ID().Bytes())) + require.EqualValues(tb, sAtx.Epoch, vAtx.PublishEpoch) + require.True(tb, bytes.Equal(sAtx.CommitmentAtx, commitAtx.Bytes())) + require.EqualValues(tb, sAtx.VrfNonce, vrfnonce) + require.Equal(tb, sAtx.NumUnits, vAtx.NumUnits) + require.Equal(tb, sAtx.BaseTickHeight, vAtx.BaseTickHeight()) + require.Equal(tb, sAtx.TickCount, vAtx.TickCount()) + require.True(tb, bytes.Equal(sAtx.PublicKey, 
vAtx.SmesherID.Bytes())) + require.Equal(tb, sAtx.Sequence, vAtx.Sequence) + require.True(tb, bytes.Equal(sAtx.Coinbase, vAtx.Coinbase.Bytes())) + require.True(tb, vAtx.Golden()) } -func accountequal(tb testing.TB, cacct types.AccountSnapshot, acct *types.Account) { +func accountEqual(tb testing.TB, cacct types.AccountSnapshot, acct *types.Account) { require.True(tb, bytes.Equal(cacct.Address, acct.Address.Bytes())) require.Equal(tb, cacct.Balance, acct.Balance) require.Equal(tb, cacct.Nonce, acct.NextNonce) @@ -75,7 +76,7 @@ func accountequal(tb testing.TB, cacct types.AccountSnapshot, acct *types.Accoun func verifyDbContent(tb testing.TB, db *sql.Database) { var expected types.Checkpoint - require.NoError(tb, json.Unmarshal([]byte(checkpointdata), &expected)) + require.NoError(tb, json.Unmarshal([]byte(checkpointData), &expected)) expAtx := map[types.ATXID]types.AtxSnapshot{} for _, satx := range expected.Data.Atxs { expAtx[types.ATXID(types.BytesToHash(satx.ID))] = satx @@ -92,12 +93,12 @@ func verifyDbContent(tb testing.TB, db *sql.Database) { for _, id := range allIds { vatx, err := atxs.Get(db, id) require.NoError(tb, err) - commitatx, err := atxs.CommitmentATX(db, vatx.SmesherID) + commitAtx, err := atxs.CommitmentATX(db, vatx.SmesherID) require.NoError(tb, err) - vrfnonce, err := atxs.VRFNonce(db, vatx.SmesherID, vatx.PublishEpoch+1) + vrfNonce, err := atxs.VRFNonce(db, vatx.SmesherID, vatx.PublishEpoch+1) require.NoError(tb, err) if _, ok := expAtx[id]; ok { - atxequal(tb, expAtx[id], vatx, commitatx, vrfnonce) + atxEqual(tb, expAtx[id], vatx, commitAtx, vrfNonce) } else { extra = append(extra, vatx) } @@ -108,7 +109,7 @@ func verifyDbContent(tb testing.TB, db *sql.Database) { cacct, ok := expAcct[acct.Address] require.True(tb, ok) require.NotNil(tb, acct) - accountequal(tb, cacct, acct) + accountEqual(tb, cacct, acct) require.EqualValues(tb, recoverLayer-1, acct.Layer) } require.Empty(tb, extra) @@ -118,7 +119,7 @@ func TestRecover(t *testing.T) { ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodGet, r.Method) w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(checkpointdata)) + _, err := w.Write([]byte(checkpointData)) require.NoError(t, err) })) defer ts.Close() @@ -157,7 +158,7 @@ func TestRecover(t *testing.T) { DbFile: "test.sql", LocalDbFile: "local.sql", PreserveOwnAtx: true, - NodeID: types.NodeID{2, 3, 4}, + NodeIDs: []types.NodeID{types.RandomNodeID()}, Uri: tc.uri, Restore: types.LayerID(recoverLayer), } @@ -172,12 +173,12 @@ func TestRecover(t *testing.T) { } require.NoError(t, err) require.Nil(t, preserve) - newdb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, newdb) - defer newdb.Close() - verifyDbContent(t, newdb) - restore, err := recovery.CheckpointInfo(newdb) + require.NotNil(t, newDB) + defer newDB.Close() + verifyDbContent(t, newDB) + restore, err := recovery.CheckpointInfo(newDB) require.NoError(t, err) require.EqualValues(t, recoverLayer, restore) exist, err := afero.Exists(fs, bsdir) @@ -191,7 +192,7 @@ func TestRecover_SameRecoveryInfo(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodGet, r.Method) w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(checkpointdata)) + _, err := w.Write([]byte(checkpointData)) require.NoError(t, err) })) defer ts.Close() @@ -204,7 +205,7 @@ func TestRecover_SameRecoveryInfo(t *testing.T) { DataDir: t.TempDir(), DbFile: "test.sql", PreserveOwnAtx: true, - NodeID: types.NodeID{2, 3, 4}, + NodeIDs: []types.NodeID{types.RandomNodeID()}, Uri: fmt.Sprintf("%s/snapshot-15", ts.URL), Restore: types.LayerID(recoverLayer), } @@ -280,7 +281,9 @@ func validateAndPreserveData( mtrtl.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) require.NoError(tb, 
atxHandler.HandleSyncedAtx(context.Background(), vatx.ID().Hash32(), "self", encoded)) err = poetDb.ValidateAndStore(context.Background(), proofs[i]) - require.ErrorContains(tb, err, "failed to validate poet proof for poetID 706f65745f round 1337") + require.ErrorContains(tb, err, fmt.Sprintf("failed to validate poet proof for poetID %s round 1337", + hex.EncodeToString(proofs[i].PoetServiceID[:5])), + ) } } @@ -289,7 +292,7 @@ func newChainedAtx( prev, pos types.ATXID, commitAtx *types.ATXID, epoch uint32, - seq, vrfnonce uint64, + seq, vrfNonce uint64, sig *signing.EdSigner, ) *types.VerifiedActivationTx { atx := &types.ActivationTx{ @@ -315,68 +318,76 @@ func newChainedAtx( nodeID := sig.NodeID() atx.NodeID = &nodeID } - if vrfnonce != 0 { - atx.VRFNonce = (*types.VRFPostIndex)(&vrfnonce) + if vrfNonce != 0 { + atx.VRFNonce = (*types.VRFPostIndex)(&vrfNonce) } atx.SmesherID = sig.NodeID() atx.SetEffectiveNumUnits(atx.NumUnits) atx.SetReceived(time.Now().Local()) atx.Signature = sig.Sign(signing.ATX, atx.SignedBytes()) - return newvatx(tb, atx) + return newvAtx(tb, atx) } -func createAtxChain(tb testing.TB, sig *signing.EdSigner) ([]*types.VerifiedActivationTx, []*types.PoetProofMessage) { - other, err := signing.NewEdSigner() - require.NoError(tb, err) +func createInterlinkedAtxChain( + tb testing.TB, + sig1 *signing.EdSigner, + sig2 *signing.EdSigner, +) ([]*types.VerifiedActivationTx, []*types.PoetProofMessage) { // epoch 2 - othAtx1 := newChainedAtx(tb, types.EmptyATXID, goldenAtx, &goldenAtx, 2, 0, 113, other) + sig1Atx1 := newChainedAtx(tb, types.EmptyATXID, goldenAtx, &goldenAtx, 2, 0, 113, sig1) // epoch 3 - othAtx2 := newChainedAtx(tb, othAtx1.ID(), othAtx1.ID(), nil, 3, 1, 0, other) + sig1Atx2 := newChainedAtx(tb, sig1Atx1.ID(), sig1Atx1.ID(), nil, 3, 1, 0, sig1) // epoch 4 - othAtx3 := newChainedAtx(tb, othAtx2.ID(), othAtx2.ID(), nil, 4, 2, 0, other) - commitAtxID := othAtx2.ID() - atx1 := newChainedAtx(tb, types.EmptyATXID, othAtx2.ID(), 
&commitAtxID, 4, 0, 513, sig) + sig1Atx3 := newChainedAtx(tb, sig1Atx2.ID(), sig1Atx2.ID(), nil, 4, 2, 0, sig1) + commitAtxID := sig1Atx2.ID() + sig2Atx1 := newChainedAtx(tb, types.EmptyATXID, sig1Atx2.ID(), &commitAtxID, 4, 0, 513, sig2) // epoch 5 - othAtx4 := newChainedAtx(tb, othAtx3.ID(), atx1.ID(), nil, 5, 3, 0, other) + sig1Atx4 := newChainedAtx(tb, sig1Atx3.ID(), sig2Atx1.ID(), nil, 5, 3, 0, sig1) // epoch 6 - othAtx5 := newChainedAtx(tb, othAtx4.ID(), othAtx4.ID(), nil, 6, 4, 0, other) - atx2 := newChainedAtx(tb, atx1.ID(), othAtx4.ID(), nil, 6, 1, 0, sig) + sig1Atx5 := newChainedAtx(tb, sig1Atx4.ID(), sig1Atx4.ID(), nil, 6, 4, 0, sig1) + sig2Atx2 := newChainedAtx(tb, sig2Atx1.ID(), sig1Atx4.ID(), nil, 6, 1, 0, sig2) // epoch 7 - othAtx6 := newChainedAtx(tb, othAtx5.ID(), atx2.ID(), nil, 7, 5, 0, other) + sig1Atx6 := newChainedAtx(tb, sig1Atx5.ID(), sig2Atx2.ID(), nil, 7, 5, 0, sig1) // epoch 8 - atx3 := newChainedAtx(tb, atx2.ID(), othAtx6.ID(), nil, 8, 2, 0, sig) + sig2Atx3 := newChainedAtx(tb, sig2Atx2.ID(), sig1Atx6.ID(), nil, 8, 2, 0, sig2) // epoch 9 - othAtx7 := newChainedAtx(tb, othAtx6.ID(), atx3.ID(), nil, 9, 6, 0, other) - - vatxs := []*types.VerifiedActivationTx{ - othAtx1, - othAtx2, - othAtx3, - atx1, - othAtx4, - othAtx5, - atx2, - othAtx6, - atx3, - othAtx7, + sig1Atx7 := newChainedAtx(tb, sig1Atx6.ID(), sig2Atx3.ID(), nil, 9, 6, 0, sig1) + + vAtxs := []*types.VerifiedActivationTx{ + sig1Atx1, + sig1Atx2, + sig1Atx3, + sig2Atx1, + sig1Atx4, + sig1Atx5, + sig2Atx2, + sig1Atx6, + sig2Atx3, + sig1Atx7, } var proofs []*types.PoetProofMessage - for range vatxs { + for range vAtxs { proofMessage := &types.PoetProofMessage{ PoetProof: types.PoetProof{ MerkleProof: shared.MerkleProof{ - Root: []byte{1, 2, 3}, + Root: types.RandomBytes(32), ProvenLeaves: [][]byte{{1}, {2}}, ProofNodes: [][]byte{{1}, {2}}, }, LeafCount: 1234, }, - PoetServiceID: []byte("poet_id_123456"), + PoetServiceID: types.RandomBytes(32), RoundID: "1337", } proofs = 
append(proofs, proofMessage) } - return vatxs, proofs + return vAtxs, proofs +} + +func createAtxChain(tb testing.TB, sig *signing.EdSigner) ([]*types.VerifiedActivationTx, []*types.PoetProofMessage) { + other, err := signing.NewEdSigner() + require.NoError(tb, err) + return createInterlinkedAtxChain(tb, other, sig) } func createAtxChainDepsOnly(tb testing.TB) ([]*types.VerifiedActivationTx, []*types.PoetProofMessage) { @@ -388,9 +399,9 @@ func createAtxChainDepsOnly(tb testing.TB) ([]*types.VerifiedActivationTx, []*ty othAtx2 := newChainedAtx(tb, othAtx1.ID(), othAtx1.ID(), nil, 3, 1, 0, other) // epoch 4 othAtx3 := newChainedAtx(tb, othAtx2.ID(), othAtx2.ID(), nil, 4, 2, 0, other) - vatxs := []*types.VerifiedActivationTx{othAtx1, othAtx2, othAtx3} + vAtxs := []*types.VerifiedActivationTx{othAtx1, othAtx2, othAtx3} var proofs []*types.PoetProofMessage - for range vatxs { + for range vAtxs { proofMessage := &types.PoetProofMessage{ PoetProof: types.PoetProof{ MerkleProof: shared.MerkleProof{ @@ -405,7 +416,7 @@ func createAtxChainDepsOnly(tb testing.TB) ([]*types.VerifiedActivationTx, []*ty } proofs = append(proofs, proofMessage) } - return vatxs, proofs + return vAtxs, proofs } func atxIDs(atxs []*types.VerifiedActivationTx) []types.ATXID { @@ -416,69 +427,96 @@ func atxIDs(atxs []*types.VerifiedActivationTx) []types.ATXID { return ids } +func proofRefs(proofs []*types.PoetProofMessage) []types.PoetProofRef { + refs := make([]types.PoetProofRef, 0, len(proofs)) + for _, proof := range proofs { + ref, _ := proof.Ref() + refs = append(refs, ref) + } + return refs +} + func TestRecover_OwnAtxNotInCheckpoint_Preserve(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodGet, r.Method) w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(checkpointdata)) + _, err := w.Write([]byte(checkpointData)) require.NoError(t, err) })) defer ts.Close() ctx, cancel := 
context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - sig, err := signing.NewEdSigner() + sig1, err := signing.NewEdSigner() + require.NoError(t, err) + sig2, err := signing.NewEdSigner() + require.NoError(t, err) + sig3, err := signing.NewEdSigner() require.NoError(t, err) + sig4, err := signing.NewEdSigner() + require.NoError(t, err) + cfg := &checkpoint.RecoverConfig{ GoldenAtx: goldenAtx, DataDir: t.TempDir(), DbFile: "test.sql", LocalDbFile: "local.sql", PreserveOwnAtx: true, - NodeID: sig.NodeID(), + NodeIDs: []types.NodeID{sig1.NodeID(), sig2.NodeID(), sig3.NodeID(), sig4.NodeID()}, Uri: fmt.Sprintf("%s/snapshot-15", ts.URL), Restore: types.LayerID(recoverLayer), } - olddb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, olddb) + require.NotNil(t, oldDB) - vatxs, proofs := createAtxChain(t, sig) - validateAndPreserveData(t, olddb, vatxs, proofs) + vAtxs1, proofs1 := createAtxChain(t, sig1) + vAtxs2, proofs2 := createAtxChain(t, sig2) + vAtxs := append(vAtxs1, vAtxs2...) + proofs := append(proofs1, proofs2...) + vAtxs3, proofs3 := createInterlinkedAtxChain(t, sig3, sig4) + vAtxs = append(vAtxs, vAtxs3...) + proofs = append(proofs, proofs3...) 
+ validateAndPreserveData(t, oldDB, vAtxs, proofs) // the proofs are not valid, but save them anyway for the purpose of testing - for i, vatx := range vatxs { + for i, vatx := range vAtxs { encoded, err := codec.Encode(proofs[i]) require.NoError(t, err) - require.NoError( - t, - poets.Add( - olddb, - types.PoetProofRef(vatx.GetPoetProofRef()), - encoded, - proofs[i].PoetServiceID, - proofs[i].RoundID, - ), + + err = poets.Add( + oldDB, + types.PoetProofRef(vatx.GetPoetProofRef()), + encoded, + proofs[i].PoetServiceID, + proofs[i].RoundID, ) + require.NoError(t, err) } - require.NoError(t, olddb.Close()) + require.NoError(t, oldDB.Close()) preserve, err := checkpoint.Recover(ctx, logtest.New(t), afero.NewOsFs(), cfg) require.NoError(t, err) require.NotNil(t, preserve) - // the two set of atxs have different received time. just compare IDs - require.ElementsMatch(t, atxIDs(vatxs[:len(vatxs)-1]), atxIDs(preserve.Deps)) - require.ElementsMatch(t, proofs[:len(proofs)-1], preserve.Proofs) - newdb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + // the last atx of single chains is not included in the checkpoint because it is not part of sig1 and sig2's + // atx chains. atxs have different timestamps for received time, so we compare just the IDs + atxRef := atxIDs(append(vAtxs1[:len(vAtxs1)-1], vAtxs2[:len(vAtxs2)-1]...)) + atxRef = append(atxRef, atxIDs(vAtxs3)...) + proofRef := proofRefs(append(proofs1[:len(vAtxs1)-1], proofs2[:len(vAtxs2)-1]...)) + proofRef = append(proofRef, proofRefs(proofs3)...) 
+ require.ElementsMatch(t, atxRef, atxIDs(preserve.Deps)) + require.ElementsMatch(t, proofRef, proofRefs(preserve.Proofs)) + + newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, newdb) - t.Cleanup(func() { require.NoError(t, newdb.Close()) }) - verifyDbContent(t, newdb) - validateAndPreserveData(t, newdb, preserve.Deps, preserve.Proofs) - // note that poet proofs are not saved to newdb due to verification errors + require.NotNil(t, newDB) + t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) + verifyDbContent(t, newDB) + validateAndPreserveData(t, newDB, preserve.Deps, preserve.Proofs) + // note that poet proofs are not saved to newDB due to verification errors - restore, err := recovery.CheckpointInfo(newdb) + restore, err := recovery.CheckpointInfo(newDB) require.NoError(t, err) require.EqualValues(t, recoverLayer, restore) @@ -492,40 +530,48 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodGet, r.Method) w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(checkpointdata)) + _, err := w.Write([]byte(checkpointData)) require.NoError(t, err) })) defer ts.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - sig, err := signing.NewEdSigner() + sig1, err := signing.NewEdSigner() + require.NoError(t, err) + sig2, err := signing.NewEdSigner() require.NoError(t, err) + sig3, err := signing.NewEdSigner() + require.NoError(t, err) + cfg := &checkpoint.RecoverConfig{ GoldenAtx: goldenAtx, DataDir: t.TempDir(), DbFile: "test.sql", LocalDbFile: "local.sql", PreserveOwnAtx: true, - NodeID: sig.NodeID(), + NodeIDs: []types.NodeID{sig1.NodeID(), sig2.NodeID(), sig3.NodeID()}, Uri: fmt.Sprintf("%s/snapshot-15", ts.URL), Restore: types.LayerID(recoverLayer), } - olddb, err := sql.Open("file:" + 
filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, olddb) + require.NotNil(t, oldDB) - vatxs, proofs := createAtxChain(t, sig) - validateAndPreserveData(t, olddb, vatxs, proofs) + vAtxs1, proofs1 := createAtxChain(t, sig1) + vAtxs2, proofs2 := createInterlinkedAtxChain(t, sig2, sig3) + vAtxs := append(vAtxs1, vAtxs2...) + proofs := append(proofs1, proofs2...) + validateAndPreserveData(t, oldDB, vAtxs, proofs) // the proofs are not valid, but save them anyway for the purpose of testing - for i, vatx := range vatxs { + for i, vatx := range vAtxs { encoded, err := codec.Encode(proofs[i]) require.NoError(t, err) require.NoError( t, poets.Add( - olddb, + oldDB, types.PoetProofRef(vatx.GetPoetProofRef()), encoded, proofs[i].PoetServiceID, @@ -533,21 +579,32 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) { ), ) } - require.NoError(t, olddb.Close()) + require.NoError(t, oldDB.Close()) // write pending nipost challenge to simulate a pending atx still waiting for poet proof. 
- prevAtx := vatxs[len(vatxs)-2] - posAtx := vatxs[len(vatxs)-1] + prevAtx1 := vAtxs1[len(vAtxs1)-2] + posAtx1 := vAtxs1[len(vAtxs1)-1] + + prevAtx2 := vAtxs2[len(vAtxs2)-2] + posAtx2 := vAtxs2[len(vAtxs2)-1] localDB, err := localsql.Open("file:" + filepath.Join(cfg.DataDir, cfg.LocalDbFile)) require.NoError(t, err) require.NotNil(t, localDB) - err = nipost.AddChallenge(localDB, sig.NodeID(), &types.NIPostChallenge{ - PublishEpoch: posAtx.PublishEpoch + 1, - Sequence: prevAtx.Sequence + 1, - PrevATXID: prevAtx.ID(), - PositioningATX: posAtx.ID(), + err = nipost.AddChallenge(localDB, sig1.NodeID(), &types.NIPostChallenge{ + PublishEpoch: posAtx1.PublishEpoch + 1, + Sequence: prevAtx1.Sequence + 1, + PrevATXID: prevAtx1.ID(), + PositioningATX: posAtx1.ID(), + }) + require.NoError(t, err) + + err = nipost.AddChallenge(localDB, sig2.NodeID(), &types.NIPostChallenge{ + PublishEpoch: posAtx2.PublishEpoch + 1, + Sequence: prevAtx2.Sequence + 1, + PrevATXID: prevAtx2.ID(), + PositioningATX: posAtx2.ID(), }) require.NoError(t, err) require.NoError(t, localDB.Close()) @@ -555,19 +612,22 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) { preserve, err := checkpoint.Recover(ctx, logtest.New(t), afero.NewOsFs(), cfg) require.NoError(t, err) require.NotNil(t, preserve) + // the two set of atxs have different received time. 
just compare IDs - require.ElementsMatch(t, atxIDs(vatxs), atxIDs(preserve.Deps)) - require.ElementsMatch(t, proofs, preserve.Proofs) + atxRef := atxIDs(append(vAtxs1, vAtxs2...)) + proofRef := proofRefs(append(proofs1, proofs2...)) + require.ElementsMatch(t, atxRef, atxIDs(preserve.Deps)) + require.ElementsMatch(t, proofRef, proofRefs(preserve.Proofs)) - newdb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, newdb) - t.Cleanup(func() { require.NoError(t, newdb.Close()) }) - verifyDbContent(t, newdb) - validateAndPreserveData(t, newdb, preserve.Deps, preserve.Proofs) - // note that poet proofs are not saved to newdb due to verification errors + require.NotNil(t, newDB) + t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) + verifyDbContent(t, newDB) + validateAndPreserveData(t, newDB, preserve.Deps, preserve.Proofs) + // note that poet proofs are not saved to newDB due to verification errors - restore, err := recovery.CheckpointInfo(newdb) + restore, err := recovery.CheckpointInfo(newDB) require.NoError(t, err) require.EqualValues(t, recoverLayer, restore) @@ -581,40 +641,43 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodGet, r.Method) w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(checkpointdata)) + _, err := w.Write([]byte(checkpointData)) require.NoError(t, err) })) defer ts.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - sig, err := signing.NewEdSigner() + sig1, err := signing.NewEdSigner() + require.NoError(t, err) + sig2, err := signing.NewEdSigner() require.NoError(t, err) + cfg := &checkpoint.RecoverConfig{ GoldenAtx: goldenAtx, DataDir: t.TempDir(), DbFile: "test.sql", LocalDbFile: "local.sql", PreserveOwnAtx: 
true, - NodeID: sig.NodeID(), + NodeIDs: []types.NodeID{sig1.NodeID(), sig2.NodeID()}, Uri: fmt.Sprintf("%s/snapshot-15", ts.URL), Restore: types.LayerID(recoverLayer), } - olddb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, olddb) + require.NotNil(t, oldDB) - vatxs, proofs := createAtxChainDepsOnly(t) - validateAndPreserveData(t, olddb, vatxs, proofs) + vAtxs, proofs := createAtxChainDepsOnly(t) + validateAndPreserveData(t, oldDB, vAtxs, proofs) // the proofs are not valid, but save them anyway for the purpose of testing - for i, vatx := range vatxs { + for i, vatx := range vAtxs { encoded, err := codec.Encode(proofs[i]) require.NoError(t, err) require.NoError( t, poets.Add( - olddb, + oldDB, types.PoetProofRef(vatx.GetPoetProofRef()), encoded, proofs[i].PoetServiceID, @@ -623,7 +686,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) ) } - require.NoError(t, olddb.Close()) + require.NoError(t, oldDB.Close()) localDB, err := localsql.Open("file:" + filepath.Join(cfg.DataDir, cfg.LocalDbFile)) require.NoError(t, err) @@ -633,7 +696,17 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) Indices: []byte{1, 2, 3}, } commitmentAtx := types.RandomATXID() - err = nipost.AddChallenge(localDB, sig.NodeID(), &types.NIPostChallenge{ + err = nipost.AddChallenge(localDB, sig1.NodeID(), &types.NIPostChallenge{ + PublishEpoch: 0, // will be updated later + Sequence: 0, + PrevATXID: types.EmptyATXID, // initial has no previous ATX + PositioningATX: types.EmptyATXID, // will be updated later + InitialPost: &post, + CommitmentATX: &commitmentAtx, + }) + require.NoError(t, err) + + err = nipost.AddChallenge(localDB, sig2.NodeID(), &types.NIPostChallenge{ PublishEpoch: 0, // will be updated later Sequence: 0, PrevATXID: types.EmptyATXID, // initial has no previous ATX @@ -648,12 
+721,12 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) require.NoError(t, err) require.Nil(t, preserve) - newdb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, newdb) - t.Cleanup(func() { require.NoError(t, newdb.Close()) }) - verifyDbContent(t, newdb) - restore, err := recovery.CheckpointInfo(newdb) + require.NotNil(t, newDB) + t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) + verifyDbContent(t, newDB) + restore, err := recovery.CheckpointInfo(newDB) require.NoError(t, err) require.EqualValues(t, recoverLayer, restore) @@ -667,7 +740,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodGet, r.Method) w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(checkpointdata)) + _, err := w.Write([]byte(checkpointData)) require.NoError(t, err) })) defer ts.Close() @@ -682,18 +755,18 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { DbFile: "test.sql", LocalDbFile: "local.sql", PreserveOwnAtx: true, - NodeID: sig.NodeID(), + NodeIDs: []types.NodeID{sig.NodeID()}, Uri: fmt.Sprintf("%s/snapshot-15", ts.URL), Restore: types.LayerID(recoverLayer), } - olddb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, olddb) - vatxs, proofs := createAtxChain(t, sig) + require.NotNil(t, oldDB) + vAtxs, proofs := createAtxChain(t, sig) // make the first one from the previous snapshot - golden := vatxs[0] - require.NoError(t, atxs.AddCheckpointed(olddb, &atxs.CheckpointAtx{ + golden := vAtxs[0] + require.NoError(t, atxs.AddCheckpointed(oldDB, &atxs.CheckpointAtx{ ID: golden.ID(), Epoch: 
golden.PublishEpoch, CommitmentATX: *golden.CommitmentATX, @@ -705,7 +778,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { Sequence: golden.Sequence, Coinbase: golden.Coinbase, })) - validateAndPreserveData(t, olddb, vatxs[1:], proofs[1:]) + validateAndPreserveData(t, oldDB, vAtxs[1:], proofs[1:]) // the proofs are not valid, but save them anyway for the purpose of testing for i, proof := range proofs { if i == 0 { @@ -716,26 +789,26 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { require.NoError( t, poets.Add( - olddb, - types.PoetProofRef(vatxs[i].GetPoetProofRef()), + oldDB, + types.PoetProofRef(vAtxs[i].GetPoetProofRef()), encoded, proof.PoetServiceID, proof.RoundID, ), ) } - require.NoError(t, olddb.Close()) + require.NoError(t, oldDB.Close()) preserve, err := checkpoint.Recover(ctx, logtest.New(t), afero.NewOsFs(), cfg) require.NoError(t, err) require.Nil(t, preserve) - newdb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, newdb) - t.Cleanup(func() { require.NoError(t, newdb.Close()) }) - verifyDbContent(t, newdb) - restore, err := recovery.CheckpointInfo(newdb) + require.NotNil(t, newDB) + t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) + verifyDbContent(t, newDB) + restore, err := recovery.CheckpointInfo(newDB) require.NoError(t, err) require.EqualValues(t, recoverLayer, restore) @@ -749,7 +822,7 @@ func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodGet, r.Method) w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(checkpointdata)) + _, err := w.Write([]byte(checkpointData)) require.NoError(t, err) })) defer ts.Close() @@ -764,24 +837,24 @@ func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) { DbFile: 
"test.sql", LocalDbFile: "local.sql", PreserveOwnAtx: false, - NodeID: sig.NodeID(), + NodeIDs: []types.NodeID{sig.NodeID()}, Uri: fmt.Sprintf("%s/snapshot-15", ts.URL), Restore: types.LayerID(recoverLayer), } - olddb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, olddb) - vatxs, proofs := createAtxChain(t, sig) - validateAndPreserveData(t, olddb, vatxs, proofs) + require.NotNil(t, oldDB) + vAtxs, proofs := createAtxChain(t, sig) + validateAndPreserveData(t, oldDB, vAtxs, proofs) // the proofs are not valid, but save them anyway for the purpose of testing - for i, vatx := range vatxs { + for i, vatx := range vAtxs { encoded, err := codec.Encode(proofs[i]) require.NoError(t, err) require.NoError( t, poets.Add( - olddb, + oldDB, types.PoetProofRef(vatx.GetPoetProofRef()), encoded, proofs[i].PoetServiceID, @@ -789,18 +862,18 @@ func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) { ), ) } - require.NoError(t, olddb.Close()) + require.NoError(t, oldDB.Close()) preserve, err := checkpoint.Recover(ctx, logtest.New(t), afero.NewOsFs(), cfg) require.NoError(t, err) require.Nil(t, preserve) - newdb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, newdb) - t.Cleanup(func() { require.NoError(t, newdb.Close()) }) - verifyDbContent(t, newdb) - restore, err := recovery.CheckpointInfo(newdb) + require.NotNil(t, newDB) + t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) + verifyDbContent(t, newDB) + restore, err := recovery.CheckpointInfo(newDB) require.NoError(t, err) require.EqualValues(t, recoverLayer, restore) @@ -814,7 +887,7 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, 
http.MethodGet, r.Method) w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(checkpointdata)) + _, err := w.Write([]byte(checkpointData)) require.NoError(t, err) })) defer ts.Close() @@ -826,7 +899,7 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { nid := types.BytesToNodeID(data) data, err = hex.DecodeString("98e47278c1f58acfd2b670a730f28898f74eb140482a07b91ff81f9ff0b7d9f4") require.NoError(t, err) - atx := newatx(types.ATXID(types.BytesToHash(data)), nil, 3, 1, 0, nid) + atx := newAtx(types.ATXID(types.BytesToHash(data)), nil, 3, 1, 0, nid) cfg := &checkpoint.RecoverConfig{ GoldenAtx: goldenAtx, @@ -834,27 +907,27 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { DbFile: "test.sql", LocalDbFile: "local.sql", PreserveOwnAtx: true, - NodeID: nid, + NodeIDs: []types.NodeID{nid}, Uri: fmt.Sprintf("%s/snapshot-15", ts.URL), Restore: types.LayerID(recoverLayer), } - olddb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, olddb) - require.NoError(t, atxs.Add(olddb, newvatx(t, atx))) - require.NoError(t, olddb.Close()) + require.NotNil(t, oldDB) + require.NoError(t, atxs.Add(oldDB, newvAtx(t, atx))) + require.NoError(t, oldDB.Close()) preserve, err := checkpoint.Recover(ctx, logtest.New(t), afero.NewOsFs(), cfg) require.NoError(t, err) require.Nil(t, preserve) - newdb, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) - require.NotNil(t, newdb) - t.Cleanup(func() { require.NoError(t, newdb.Close()) }) - verifyDbContent(t, newdb) - restore, err := recovery.CheckpointInfo(newdb) + require.NotNil(t, newDB) + t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) + verifyDbContent(t, newDB) + restore, err := recovery.CheckpointInfo(newDB) require.NoError(t, err) require.EqualValues(t, recoverLayer, restore) diff 
--git a/checkpoint/runner_test.go b/checkpoint/runner_test.go index 98003f0be8..600493072c 100644 --- a/checkpoint/runner_test.go +++ b/checkpoint/runner_test.go @@ -32,29 +32,29 @@ func TestMain(m *testing.M) { var allAtxs = map[types.NodeID][]*types.ActivationTx{ // smesher 1 has 7 ATXs, one in each epoch from 1 to 7 types.BytesToNodeID([]byte("smesher1")): { - newatx(types.ATXID{17}, nil, 7, 6, 0, types.BytesToNodeID([]byte("smesher1"))), - newatx(types.ATXID{16}, nil, 6, 5, 0, types.BytesToNodeID([]byte("smesher1"))), - newatx(types.ATXID{15}, nil, 5, 4, 0, types.BytesToNodeID([]byte("smesher1"))), - newatx(types.ATXID{14}, nil, 4, 3, 0, types.BytesToNodeID([]byte("smesher1"))), - newatx(types.ATXID{13}, nil, 3, 2, 0, types.BytesToNodeID([]byte("smesher1"))), - newatx(types.ATXID{12}, nil, 2, 1, 0, types.BytesToNodeID([]byte("smesher1"))), - newatx(types.ATXID{11}, &types.ATXID{1}, 1, 0, 123, types.BytesToNodeID([]byte("smesher1"))), + newAtx(types.ATXID{17}, nil, 7, 6, 0, types.BytesToNodeID([]byte("smesher1"))), + newAtx(types.ATXID{16}, nil, 6, 5, 0, types.BytesToNodeID([]byte("smesher1"))), + newAtx(types.ATXID{15}, nil, 5, 4, 0, types.BytesToNodeID([]byte("smesher1"))), + newAtx(types.ATXID{14}, nil, 4, 3, 0, types.BytesToNodeID([]byte("smesher1"))), + newAtx(types.ATXID{13}, nil, 3, 2, 0, types.BytesToNodeID([]byte("smesher1"))), + newAtx(types.ATXID{12}, nil, 2, 1, 0, types.BytesToNodeID([]byte("smesher1"))), + newAtx(types.ATXID{11}, &types.ATXID{1}, 1, 0, 123, types.BytesToNodeID([]byte("smesher1"))), }, // smesher 2 has 1 ATX in epoch 7 types.BytesToNodeID([]byte("smesher2")): { - newatx(types.ATXID{27}, &types.ATXID{2}, 7, 0, 152, types.BytesToNodeID([]byte("smesher2"))), + newAtx(types.ATXID{27}, &types.ATXID{2}, 7, 0, 152, types.BytesToNodeID([]byte("smesher2"))), }, // smesher 3 has 1 ATX in epoch 2 types.BytesToNodeID([]byte("smesher3")): { - newatx(types.ATXID{32}, &types.ATXID{3}, 2, 0, 211, types.BytesToNodeID([]byte("smesher3"))), + 
newAtx(types.ATXID{32}, &types.ATXID{3}, 2, 0, 211, types.BytesToNodeID([]byte("smesher3"))), }, // smesher 4 has 1 ATX in epoch 3 and one in epoch 7 types.BytesToNodeID([]byte("smesher4")): { - newatx(types.ATXID{47}, nil, 7, 1, 0, types.BytesToNodeID([]byte("smesher4"))), - newatx(types.ATXID{43}, &types.ATXID{4}, 4, 0, 420, types.BytesToNodeID([]byte("smesher4"))), + newAtx(types.ATXID{47}, nil, 7, 1, 0, types.BytesToNodeID([]byte("smesher4"))), + newAtx(types.ATXID{43}, &types.ATXID{4}, 4, 0, 420, types.BytesToNodeID([]byte("smesher4"))), }, } @@ -155,7 +155,7 @@ func expectedCheckpoint(t *testing.T, snapshot types.LayerID, numAtxs int) *type for i := 0; i < n; i++ { atxData = append( atxData, - toShortAtx(newvatx(t, atxs[i]), atxs[len(atxs)-1].CommitmentATX, atxs[len(atxs)-1].VRFNonce), + toShortAtx(newvAtx(t, atxs[i]), atxs[len(atxs)-1].CommitmentATX, atxs[len(atxs)-1].VRFNonce), ) } } @@ -188,7 +188,7 @@ func expectedCheckpoint(t *testing.T, snapshot types.LayerID, numAtxs int) *type return result } -func newatx( +func newAtx( id types.ATXID, commitAtx *types.ATXID, epoch uint32, @@ -221,7 +221,7 @@ func newatx( return atx } -func newvatx(tb testing.TB, atx *types.ActivationTx) *types.VerifiedActivationTx { +func newvAtx(tb testing.TB, atx *types.ActivationTx) *types.VerifiedActivationTx { vatx, err := atx.Verify(1111, 12) require.NoError(tb, err) return vatx @@ -245,7 +245,7 @@ func toShortAtx(v *types.VerifiedActivationTx, cmt *types.ATXID, nonce *types.VR func createMesh(t *testing.T, db *sql.Database, miners map[types.NodeID][]*types.ActivationTx, accts []*types.Account) { for _, vatxs := range miners { for _, atx := range vatxs { - require.NoError(t, atxs.Add(db, newvatx(t, atx))) + require.NoError(t, atxs.Add(db, newvAtx(t, atx))) } } @@ -255,8 +255,8 @@ func createMesh(t *testing.T, db *sql.Database, miners map[types.NodeID][]*types // smesher 5 is malicious and equivocated in epoch 7 bad := types.BytesToNodeID([]byte("smesher5")) - require.NoError(t, 
atxs.Add(db, newvatx(t, newatx(types.ATXID{83}, &types.ATXID{27}, 7, 0, 113, bad)))) - require.NoError(t, atxs.Add(db, newvatx(t, newatx(types.ATXID{97}, &types.ATXID{16}, 7, 0, 113, bad)))) + require.NoError(t, atxs.Add(db, newvAtx(t, newAtx(types.ATXID{83}, &types.ATXID{27}, 7, 0, 113, bad)))) + require.NoError(t, atxs.Add(db, newvAtx(t, newAtx(types.ATXID{97}, &types.ATXID{16}, 7, 0, 113, bad)))) require.NoError(t, identities.SetMalicious(db, bad, []byte("bad"), time.Now())) } @@ -360,9 +360,9 @@ func TestRunner_Generate_Error(t *testing.T) { snapshot := types.LayerID(5) var atx *types.ActivationTx if tc.missingCommitment { - atx = newatx(types.ATXID{13}, nil, 2, 1, 11, types.BytesToNodeID([]byte("smesher1"))) + atx = newAtx(types.ATXID{13}, nil, 2, 1, 11, types.BytesToNodeID([]byte("smesher1"))) } else if tc.missingVrf { - atx = newatx(types.ATXID{13}, &types.ATXID{11}, 2, 1, 0, types.BytesToNodeID([]byte("smesher1"))) + atx = newAtx(types.ATXID{13}, &types.ATXID{11}, 2, 1, 0, types.BytesToNodeID([]byte("smesher1"))) } createMesh(t, db, map[types.NodeID][]*types.ActivationTx{ types.BytesToNodeID([]byte("smesher1")): {atx}, diff --git a/checkpoint/util.go b/checkpoint/util.go index 4cf27f4715..a7b03d1bb9 100644 --- a/checkpoint/util.go +++ b/checkpoint/util.go @@ -7,9 +7,9 @@ import ( "errors" "fmt" "io" + "io/fs" "net/http" "net/url" - "os" "path/filepath" "time" @@ -28,15 +28,15 @@ type RecoveryFile struct { path string } -func NewRecoveryFile(fs afero.Fs, path string) (*RecoveryFile, error) { - if err := fs.MkdirAll(filepath.Dir(path), dirPerm); err != nil { +func NewRecoveryFile(aferoFs afero.Fs, path string) (*RecoveryFile, error) { + if err := aferoFs.MkdirAll(filepath.Dir(path), dirPerm); err != nil { return nil, fmt.Errorf("create dst dir %v: %w", filepath.Dir(path), err) } - f, _ := fs.Stat(path) + f, _ := aferoFs.Stat(path) if f != nil { - return nil, fmt.Errorf("%w: file already exist: %v", os.ErrExist, path) + return nil, fmt.Errorf("%w: file already 
exist: %v", fs.ErrExist, path) } - tmpf, err := afero.TempFile(fs, filepath.Dir(path), filepath.Base(path)) + tmpf, err := afero.TempFile(aferoFs, filepath.Dir(path), filepath.Base(path)) if err != nil { return nil, fmt.Errorf("%w: create tmp file", err) } diff --git a/checkpoint/util_test.go b/checkpoint/util_test.go index cc171127f8..579ddf44b8 100644 --- a/checkpoint/util_test.go +++ b/checkpoint/util_test.go @@ -3,7 +3,7 @@ package checkpoint_test import ( "bytes" _ "embed" - "os" + "io/fs" "path/filepath" "testing" @@ -14,7 +14,7 @@ import ( ) //go:embed checkpointdata.json -var checkpointdata string +var checkpointData string func TestValidateSchema(t *testing.T) { tcs := []struct { @@ -24,7 +24,7 @@ func TestValidateSchema(t *testing.T) { }{ { desc: "valid", - data: checkpointdata, + data: checkpointData, }, { desc: "missing atx", @@ -113,7 +113,7 @@ func TestRecoveryFile(t *testing.T) { return } require.NoError(t, err) - require.NoError(t, rf.Copy(fs, bytes.NewReader([]byte(checkpointdata)))) + require.NoError(t, rf.Copy(fs, bytes.NewReader([]byte(checkpointData)))) require.NoError(t, rf.Save(fs)) }) } @@ -127,19 +127,19 @@ func TestRecoveryFile_Copy(t *testing.T) { } func TestCopyFile(t *testing.T) { - fs := afero.NewMemMapFs() + aferoFS := afero.NewMemMapFs() dir := t.TempDir() src := filepath.Join(dir, "test_src") dst := filepath.Join(dir, "test_dest") - err := checkpoint.CopyFile(fs, src, dst) - require.ErrorIs(t, err, os.ErrNotExist) + err := checkpoint.CopyFile(aferoFS, src, dst) + require.ErrorIs(t, err, fs.ErrNotExist) // create src file - require.NoError(t, afero.WriteFile(fs, src, []byte("blah"), 0o600)) - err = checkpoint.CopyFile(fs, src, dst) + require.NoError(t, afero.WriteFile(aferoFS, src, []byte("blah"), 0o600)) + err = checkpoint.CopyFile(aferoFS, src, dst) require.NoError(t, err) // dst file cannot be copied over - err = checkpoint.CopyFile(fs, src, dst) - require.ErrorIs(t, err, os.ErrExist) + err = checkpoint.CopyFile(aferoFS, src, dst) 
+ require.ErrorIs(t, err, fs.ErrExist) } diff --git a/go.mod b/go.mod index 4a6b678112..4a8020d898 100644 --- a/go.mod +++ b/go.mod @@ -52,10 +52,10 @@ require ( github.com/zeebo/blake3 v0.2.3 go.uber.org/mock v0.4.0 go.uber.org/zap v1.27.0 - golang.org/x/exp v0.0.0-20240119083558-1b970713d09a + golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 golang.org/x/sync v0.6.0 golang.org/x/time v0.5.0 - google.golang.org/genproto/googleapis/rpc v0.0.0-20240213162025-012b6fc9bca9 + google.golang.org/genproto/googleapis/rpc v0.0.0-20240221002015-b0ce06bbee7c google.golang.org/grpc v1.62.0 google.golang.org/protobuf v1.32.0 k8s.io/api v0.29.2 @@ -200,13 +200,13 @@ require ( go.uber.org/fx v1.20.1 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.19.0 // indirect - golang.org/x/mod v0.14.0 // indirect + golang.org/x/mod v0.15.0 // indirect golang.org/x/net v0.21.0 // indirect golang.org/x/oauth2 v0.17.0 // indirect golang.org/x/sys v0.17.0 // indirect golang.org/x/term v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.17.0 // indirect + golang.org/x/tools v0.18.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect gonum.org/v1/gonum v0.13.0 // indirect google.golang.org/api v0.166.0 // indirect diff --git a/go.sum b/go.sum index 9e9c40bf16..62df026274 100644 --- a/go.sum +++ b/go.sum @@ -764,8 +764,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= 
+golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -789,8 +789,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.15.0 h1:SernR4v+D55NyBH2QiEQrlBAnj1ECL6AGrA5+dPaMY8= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -981,8 +981,8 @@ golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= 
-golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= +golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ= +golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1063,8 +1063,8 @@ google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9 h1:9+tzLLstTlPTRyJ google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9/go.mod h1:mqHbVIp48Muh7Ywss/AD6I5kNVKZMmAa/QEW58Gxp2s= google.golang.org/genproto/googleapis/api v0.0.0-20240221002015-b0ce06bbee7c h1:9g7erC9qu44ks7UK4gDNlnk4kOxZG707xKm4jVniy6o= google.golang.org/genproto/googleapis/api v0.0.0-20240221002015-b0ce06bbee7c/go.mod h1:5iCWqnniDlqZHrd3neWVTOwvh/v6s3232omMecelax8= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240213162025-012b6fc9bca9 h1:hZB7eLIaYlW9qXRfCq/qDaPdbeY3757uARz5Vvfv+cY= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240213162025-012b6fc9bca9/go.mod h1:YUWgXUFRPfoYK1IHMuxH5K6nPEXSCzIMljnQ59lLRCk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240221002015-b0ce06bbee7c h1:NUsgEN92SQQqzfA+YtqYNqYmB3DMMYLlIwUZAQFVFbo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240221002015-b0ce06bbee7c/go.mod h1:H4O17MA/PE9BsGx3w+a+W2VOLLD1Qf7oJneAoU6WktY= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= diff --git a/hare3/hare.go b/hare3/hare.go index d9b242d9f3..0bd802ec09 100644 --- a/hare3/hare.go +++ b/hare3/hare.go @@ -218,11 +218,11 @@ type Hare struct { tracer 
Tracer } -func (h *Hare) Register(signer *signing.EdSigner) { +func (h *Hare) Register(sig *signing.EdSigner) { h.mu.Lock() defer h.mu.Unlock() - h.log.Info("register signing key", zap.Stringer("node", signer.NodeID())) - h.signers[string(signer.NodeID().Bytes())] = signer + h.log.Info("registered signing key", log.ZShortStringer("id", sig.NodeID())) + h.signers[string(sig.NodeID().Bytes())] = sig } func (h *Hare) Results() <-chan ConsensusOutput { diff --git a/malfeasance/handler.go b/malfeasance/handler.go index 08ac99820f..99353b092f 100644 --- a/malfeasance/handler.go +++ b/malfeasance/handler.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "time" "github.com/spacemeshos/post/shared" @@ -33,7 +34,7 @@ type Handler struct { logger log.Log cdb *datastore.CachedDB self p2p.Peer - nodeID types.NodeID + nodeIDs []types.NodeID edVerifier SigVerifier tortoise tortoise postVerifier postVerifier @@ -43,7 +44,7 @@ func NewHandler( cdb *datastore.CachedDB, lg log.Log, self p2p.Peer, - nodeID types.NodeID, + nodeID []types.NodeID, edVerifier SigVerifier, tortoise tortoise, postVerifier postVerifier, @@ -52,7 +53,7 @@ func NewHandler( logger: lg, cdb: cdb, self: self, - nodeID: nodeID, + nodeIDs: nodeID, edVerifier: edVerifier, tortoise: tortoise, postVerifier: postVerifier, @@ -62,7 +63,7 @@ func NewHandler( func (h *Handler) reportMalfeasance(smesher types.NodeID, mp *types.MalfeasanceProof) { h.tortoise.OnMalfeasance(smesher) events.ReportMalfeasance(smesher, mp) - if h.nodeID == smesher { + if slices.Contains(h.nodeIDs, smesher) { events.EmitOwnMalfeasanceProof(smesher, mp) } } diff --git a/malfeasance/handler_test.go b/malfeasance/handler_test.go index 392306fd4e..522ce5bbc5 100644 --- a/malfeasance/handler_test.go +++ b/malfeasance/handler_test.go @@ -56,7 +56,7 @@ func TestHandler_HandleMalfeasanceProof_multipleATXs(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, 
signing.NewEdVerifier(), trt, postVerifier, @@ -270,7 +270,7 @@ func TestHandler_HandleMalfeasanceProof_multipleBallots(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, @@ -491,7 +491,7 @@ func TestHandler_HandleMalfeasanceProof_hareEquivocation(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, @@ -727,7 +727,7 @@ func TestHandler_HandleMalfeasanceProof_validateHare(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, @@ -789,7 +789,7 @@ func TestHandler_CrossDomain(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, @@ -853,7 +853,7 @@ func TestHandler_HandleSyncedMalfeasanceProof_multipleATXs(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, @@ -916,7 +916,7 @@ func TestHandler_HandleSyncedMalfeasanceProof_multipleBallots(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, @@ -978,7 +978,7 @@ func TestHandler_HandleSyncedMalfeasanceProof_hareEquivocation(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, @@ -1043,7 +1043,7 @@ func TestHandler_HandleSyncedMalfeasanceProof_wrongHash(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, @@ -1125,7 +1125,7 @@ func 
TestHandler_HandleMalfeasanceProof_InvalidPostIndex(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, @@ -1162,7 +1162,7 @@ func TestHandler_HandleMalfeasanceProof_InvalidPostIndex(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, @@ -1198,7 +1198,7 @@ func TestHandler_HandleMalfeasanceProof_InvalidPostIndex(t *testing.T) { datastore.NewCachedDB(db, lg), lg, "self", - types.EmptyNodeID, + []types.NodeID{types.RandomNodeID()}, signing.NewEdVerifier(), trt, postVerifier, diff --git a/miner/proposal_builder.go b/miner/proposal_builder.go index 16921d99fa..643c7d73c8 100644 --- a/miner/proposal_builder.go +++ b/miner/proposal_builder.go @@ -277,14 +277,15 @@ func New( return pb } -func (pb *ProposalBuilder) Register(signer *signing.EdSigner) { +func (pb *ProposalBuilder) Register(sig *signing.EdSigner) { pb.signers.mu.Lock() defer pb.signers.mu.Unlock() - _, exist := pb.signers.signers[signer.NodeID()] + _, exist := pb.signers.signers[sig.NodeID()] if !exist { - pb.signers.signers[signer.NodeID()] = &signerSession{ - signer: signer, - log: pb.logger.WithFields(log.String("signer", signer.NodeID().ShortString())), + pb.logger.With().Info("registered signing key", log.ShortStringer("id", sig.NodeID())) + pb.signers.signers[sig.NodeID()] = &signerSession{ + signer: sig, + log: pb.logger.WithFields(log.String("signer", sig.NodeID().ShortString())), } } } diff --git a/node/bad_peer_test.go b/node/bad_peer_test.go index 0b2d338aa3..e3cb3bf4a3 100644 --- a/node/bad_peer_test.go +++ b/node/bad_peer_test.go @@ -61,7 +61,6 @@ func TestPeerDisconnectForMessageResultValidationReject(t *testing.T) { <-app1.Started() t.Cleanup(func() { app1.Cleanup(ctx) - app1.eg.Wait() }) eg.Go(func() error { return app2.Start(grpContext) @@ -69,7 +68,6 @@ func 
TestPeerDisconnectForMessageResultValidationReject(t *testing.T) { <-app2.Started() t.Cleanup(func() { app2.Cleanup(ctx) - app2.eg.Wait() }) // Connect app2 to app1 diff --git a/node/node.go b/node/node.go index db5b5d1163..8897ea2ce9 100644 --- a/node/node.go +++ b/node/node.go @@ -4,9 +4,9 @@ package node import ( "bytes" "context" - "encoding/hex" "errors" "fmt" + "io/fs" "net" "net/http" "net/url" @@ -86,7 +86,6 @@ import ( ) const ( - edKeyFileName = "key.bin" genesisFileName = "genesis.json" dbFile = "state.sql" localDbFile = "node_state.sql" @@ -181,10 +180,27 @@ func GetCommand() *cobra.Command { return fmt.Errorf("initializing app: %w", err) } - /* Create or load miner identity */ + // Migrate legacy identity to new location + if err := app.MigrateExistingIdentity(); err != nil { + return fmt.Errorf("migrating existing identity: %w", err) + } + var err error - if app.edSgn, err = app.LoadOrCreateEdSigner(); err != nil { - return fmt.Errorf("could not retrieve identity: %w", err) + if app.signers, err = app.TestIdentity(); err != nil { + return fmt.Errorf("testing identity: %w", err) + } + + if app.signers == nil { + err := app.LoadIdentities() + switch { + case errors.Is(err, fs.ErrNotExist): + app.log.Info("Identity file not found. 
Creating new identity...") + if err := app.NewIdentity(); err != nil { + return fmt.Errorf("creating new identity: %w", err) + } + case err != nil: + return fmt.Errorf("loading identities: %w", err) + } } // Don't print usage on error from this point forward @@ -206,7 +222,6 @@ func GetCommand() *cobra.Command { // FIXME: per https://github.com/spacemeshos/go-spacemesh/issues/3830 go func() { app.Cleanup(cleanupCtx) - _ = app.eg.Wait() close(done) }() select { @@ -356,7 +371,7 @@ func New(opts ...Option) *App { type App struct { *cobra.Command fileLock *flock.Flock - edSgn *signing.EdSigner + signers []*signing.EdSigner Config *config.Config db *sql.Database cachedDB *datastore.CachedDB @@ -417,13 +432,14 @@ func (app *App) LoadCheckpoint(ctx context.Context) (*checkpoint.PreservedData, if restore == 0 { return nil, fmt.Errorf("restore layer not set") } + nodeIDs := make([]types.NodeID, 0, len(app.signers)) cfg := &checkpoint.RecoverConfig{ GoldenAtx: types.ATXID(app.Config.Genesis.GoldenATX()), DataDir: app.Config.DataDir(), DbFile: dbFile, LocalDbFile: localDbFile, PreserveOwnAtx: app.Config.Recovery.PreserveOwnAtx, - NodeID: app.edSgn.NodeID(), + NodeIDs: nodeIDs, Uri: checkpointFile, Restore: restore, } @@ -440,11 +456,11 @@ func (app *App) Started() <-chan struct{} { // Lock locks the app for exclusive use. It returns an error if the app is already locked. 
func (app *App) Lock() error { - lockdir := filepath.Dir(app.Config.FileLock) - if _, err := os.Stat(lockdir); errors.Is(err, os.ErrNotExist) { - err := os.Mkdir(lockdir, os.ModePerm) + lockDir := filepath.Dir(app.Config.FileLock) + if _, err := os.Stat(lockDir); errors.Is(err, fs.ErrNotExist) { + err := os.Mkdir(lockDir, os.ModePerm) if err != nil { - return fmt.Errorf("creating dir %s for lock %s: %w", lockdir, app.Config.FileLock, err) + return fmt.Errorf("creating dir %s for lock %s: %w", lockDir, app.Config.FileLock, err) } } fl := flock.New(app.Config.FileLock) @@ -476,7 +492,7 @@ func (app *App) Initialize() error { gpath := filepath.Join(app.Config.DataDir(), genesisFileName) var existing config.GenesisConfig if err := existing.LoadFromFile(gpath); err != nil { - if !errors.Is(err, os.ErrNotExist) { + if !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("failed to load genesis config at %s: %w", gpath, err) } if err := app.Config.Genesis.Validate(); err != nil { @@ -530,7 +546,7 @@ func (app *App) getAppInfo() string { func (app *App) Cleanup(ctx context.Context) { app.log.Info("app cleanup starting...") app.stopServices(ctx) - // add any other Cleanup tasks here.... 
+ app.eg.Wait() app.log.Info("app cleanup completed") } @@ -591,12 +607,18 @@ func (app *App) initServices(ctx context.Context) error { poetDb := activation.NewPoetDb(app.db, app.addLogger(PoetDbLogger, lg)) + opts := []activation.PostVerifierOpt{ + activation.WithVerifyingOpts(app.Config.SMESHING.VerifyingOpts), + activation.WithAutoscaling(), + } + for _, sig := range app.signers { + opts = append(opts, activation.WithPrioritizedID(sig.NodeID())) + } + verifier, err := activation.NewPostVerifier( app.Config.POST, app.addLogger(NipostValidatorLogger, lg).Zap(), - activation.WithVerifyingOpts(app.Config.SMESHING.VerifyingOpts), - activation.PrioritizedIDs(app.edSgn.NodeID()), - activation.WithAutoscaling(), + opts..., ) if err != nil { return fmt.Errorf("creating post verifier: %w", err) @@ -662,7 +684,9 @@ func (app *App) initServices(ctx context.Context) error { beacon.WithConfig(app.Config.Beacon), beacon.WithLogger(app.addLogger(BeaconLogger, lg)), ) - beaconProtocol.Register(app.edSgn) + for _, sig := range app.signers { + beaconProtocol.Register(sig) + } trtlCfg := app.Config.Tortoise trtlCfg.LayerSize = layerSize @@ -736,7 +760,9 @@ func (app *App) initServices(ctx context.Context) error { trtl, app.addLogger(ATXHandlerLogger, lg), ) - atxHandler.Register(app.edSgn) + for _, sig := range app.signers { + atxHandler.Register(sig) + } // we can't have an epoch offset which is greater/equal than the number of layers in an epoch @@ -795,7 +821,9 @@ func (app *App) initServices(ctx context.Context) error { blocks.WithCertConfig(app.Config.Certificate), blocks.WithCertifierLogger(app.addLogger(BlockCertLogger, lg)), ) - app.certifier.Register(app.edSgn) + for _, sig := range app.signers { + app.certifier.Register(sig) + } proposalsStore := store.New( store.WithEvictedLayer(app.clock.CurrentLayer()), @@ -855,7 +883,9 @@ func (app *App) initServices(ctx context.Context) error { hare3.WithLogger(logger), hare3.WithConfig(app.Config.HARE3), ) - 
app.hare3.Register(app.edSgn) + for _, sig := range app.signers { + app.hare3.Register(sig) + } app.hare3.Start() app.eg.Go(func() error { compat.ReportWeakcoin( @@ -927,7 +957,9 @@ func (app *App) initServices(ctx context.Context) error { miner.WithMinGoodAtxPercent(minerGoodAtxPct), miner.WithLogger(app.addLogger(ProposalBuilderLogger, lg)), ) - proposalBuilder.Register(app.edSgn) + for _, sig := range app.signers { + proposalBuilder.Register(sig) + } host, port, err := net.SplitHostPort(app.Config.API.PostListener) if err != nil { @@ -940,7 +972,6 @@ func (app *App) initServices(ctx context.Context) error { app.Config.POSTService.NodeAddress = fmt.Sprintf("http://%s:%s", host, port) postSetupMgr, err := activation.NewPostSetupManager( - app.edSgn.NodeID(), app.Config.POST, app.addLogger(PostLogger, lg).Zap(), app.cachedDB, @@ -983,6 +1014,7 @@ func (app *App) initServices(ctx context.Context) error { builderConfig := activation.Config{ GoldenATXID: goldenATXID, + LabelsPerUnit: app.Config.POST.LabelsPerUnit, RegossipInterval: app.Config.RegossipAtxInterval, } atxBuilder := activation.NewBuilder( @@ -1001,13 +1033,19 @@ func (app *App) initServices(ctx context.Context) error { activation.WithValidator(app.validator), activation.WithPostValidityDelay(app.Config.PostValidDelay), ) - atxBuilder.Register(app.edSgn) + for _, sig := range app.signers { + atxBuilder.Register(sig) + } + nodeIDs := make([]types.NodeID, 0, len(app.signers)) + for _, s := range app.signers { + nodeIDs = append(nodeIDs, s.NodeID()) + } malfeasanceHandler := malfeasance.NewHandler( app.cachedDB, app.addLogger(MalfeasanceLogger, lg), app.host.ID(), - app.edSgn.NodeID(), + nodeIDs, app.edVerifier, trtl, app.postVerifier, @@ -1297,7 +1335,10 @@ func (app *App) startServices(ctx context.Context) error { if app.Config.SMESHING.CoinbaseAccount == "" { return fmt.Errorf("smeshing enabled but no coinbase account provided") } - if err := app.postSupervisor.Start(app.Config.SMESHING.Opts); err != nil { 
+ if len(app.signers) > 1 { + return fmt.Errorf("supervised smeshing cannot be started in a multi-smeshing setup") + } + if err := app.postSupervisor.Start(app.Config.SMESHING.Opts, app.signers[0].NodeID()); err != nil { return fmt.Errorf("start post service: %w", err) } } else { @@ -1358,10 +1399,16 @@ func (app *App) grpcService(svc grpcserver.Service, lg log.Log) (grpcserver.Serv app.grpcServices[svc] = service return service, nil case grpcserver.Smesher: + var nodeID *types.NodeID + if len(app.signers) == 1 { + nodeID = new(types.NodeID) + *nodeID = app.signers[0].NodeID() + } service := grpcserver.NewSmesherService( app.atxBuilder, app.postSupervisor, app.Config.API.SmesherStreamInterval, + nodeID, app.Config.SMESHING.Opts, ) app.grpcServices[svc] = service @@ -1648,64 +1695,6 @@ func (app *App) stopServices(ctx context.Context) { grpczap.SetGrpcLoggerV2(grpclog, log.NewNop().Zap()) } -// LoadOrCreateEdSigner either loads a previously created ed identity for the node or creates a new one if not exists. -func (app *App) LoadOrCreateEdSigner() (*signing.EdSigner, error) { - filename := filepath.Join(app.Config.SMESHING.Opts.DataDir, edKeyFileName) - app.log.Info("Looking for identity file at `%v`", filename) - - var data []byte - if len(app.Config.TestConfig.SmesherKey) > 0 { - app.log.With().Error("!!!TESTING!!! using pre-configured smesher key") - data = []byte(app.Config.TestConfig.SmesherKey) - } else { - var err error - data, err = os.ReadFile(filename) - if err != nil { - if !os.IsNotExist(err) { - return nil, fmt.Errorf("failed to read identity file: %w", err) - } - - app.log.Info("Identity file not found. 
Creating new identity...") - - edSgn, err := signing.NewEdSigner( - signing.WithPrefix(app.Config.Genesis.GenesisID().Bytes()), - ) - if err != nil { - return nil, fmt.Errorf("failed to create identity: %w", err) - } - if err := os.MkdirAll(filepath.Dir(filename), 0o700); err != nil { - return nil, fmt.Errorf("failed to create directory for identity file: %w", err) - } - - err = os.WriteFile(filename, []byte(hex.EncodeToString(edSgn.PrivateKey())), 0o600) - if err != nil { - return nil, fmt.Errorf("failed to write identity file: %w", err) - } - - app.log.With().Info("created new identity", edSgn.PublicKey()) - return edSgn, nil - } - } - dst := make([]byte, signing.PrivateKeySize) - n, err := hex.Decode(dst, data) - if err != nil { - return nil, fmt.Errorf("decoding private key: %w", err) - } - if n != signing.PrivateKeySize { - return nil, fmt.Errorf("invalid key size %d/%d", n, signing.PrivateKeySize) - } - edSgn, err := signing.NewEdSigner( - signing.WithPrivateKey(dst), - signing.WithPrefix(app.Config.Genesis.GenesisID().Bytes()), - ) - if err != nil { - return nil, fmt.Errorf("failed to construct identity from data file: %w", err) - } - - app.log.Info("Loaded existing identity; public key: %v", edSgn.PublicKey()) - return edSgn, nil -} - func (app *App) setupDBs(ctx context.Context, lg log.Log) error { dbPath := app.Config.DataDir() if err := os.MkdirAll(dbPath, os.ModePerm); err != nil { diff --git a/node/node_identities.go b/node/node_identities.go new file mode 100644 index 0000000000..407e9eb599 --- /dev/null +++ b/node/node_identities.go @@ -0,0 +1,209 @@ +package node + +import ( + "encoding/hex" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + + "github.com/spacemeshos/go-spacemesh/log" + "github.com/spacemeshos/go-spacemesh/signing" +) + +const ( + legacyKeyFileName = "key.bin" + keyDir = "identities" + supervisedIDKeyFileName = "identity.key" +) + +// TestIdentity loads a pre-configured identity for testing purposes. 
+func (app *App) TestIdentity() ([]*signing.EdSigner, error) { + if len(app.Config.TestConfig.SmesherKey) == 0 { + return nil, nil + } + + app.log.With().Error("!!!TESTING!!! using pre-configured smesher key") + dst, err := hex.DecodeString(app.Config.TestConfig.SmesherKey) + if err != nil { + return nil, fmt.Errorf("decoding private key: %w", err) + } + if len(dst) != signing.PrivateKeySize { + return nil, fmt.Errorf("invalid key size %d/%d", dst, signing.PrivateKeySize) + } + signer, err := signing.NewEdSigner( + signing.WithPrivateKey(dst), + signing.WithPrefix(app.Config.Genesis.GenesisID().Bytes()), + ) + if err != nil { + return nil, fmt.Errorf("failed to construct identity from data file: %w", err) + } + + app.log.With().Info("Loaded testing identity", signer.PublicKey()) + return []*signing.EdSigner{signer}, nil +} + +// MigrateExistingIdentity migrates the legacy identity file to the new location. +// +// The legacy identity file is expected to be located at `app.Config.SMESHING.Opts.DataDir/key.bin`. +// +// TODO(mafa): this can be removed in a future version when the legacy identity file is no longer expected to exist. 
+func (app *App) MigrateExistingIdentity() error { + oldKey := filepath.Join(app.Config.SMESHING.Opts.DataDir, legacyKeyFileName) + app.log.Info("Looking for legacy identity file at `%v`", oldKey) + + src, err := os.Open(oldKey) + switch { + case errors.Is(err, os.ErrNotExist): + app.log.Info("Legacy identity file not found.") + return nil + case err != nil: + return fmt.Errorf("failed to open legacy identity file: %w", err) + } + defer src.Close() + + newKey := filepath.Join(app.Config.DataDir(), keyDir, supervisedIDKeyFileName) + if err := os.MkdirAll(filepath.Dir(newKey), 0o700); err != nil { + return fmt.Errorf("failed to create directory for identity file: %w", err) + } + + dst, err := os.Create(newKey) + if err != nil { + return fmt.Errorf("failed to create new identity file: %w", err) + } + defer dst.Close() + + if _, err := io.Copy(dst, src); err != nil { + return fmt.Errorf("failed to copy identity file: %w", err) + } + + if err := src.Close(); err != nil { + return fmt.Errorf("failed to close legacy identity file: %w", err) + } + + if err := os.Rename(oldKey, oldKey+".bak"); err != nil { + return fmt.Errorf("failed to rename legacy identity file: %w", err) + } + + app.log.Info("Migrated legacy identity file to `%v`", newKey) + return nil +} + +// NewIdentity creates a new identity, saves it to `keyDir/supervisedIDKeyFileName` in the config directory and +// initializes app.signers with that identity. 
+func (app *App) NewIdentity() error { + dir := filepath.Join(app.Config.DataDir(), keyDir) + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("failed to create directory for identity file: %w", err) + } + + signer, err := signing.NewEdSigner( + signing.WithPrefix(app.Config.Genesis.GenesisID().Bytes()), + ) + if err != nil { + return fmt.Errorf("failed to create identity: %w", err) + } + + keyFile := filepath.Join(dir, supervisedIDKeyFileName) + if _, err := os.Stat(keyFile); err == nil { + return fmt.Errorf("identity file %s already exists: %w", supervisedIDKeyFileName, fs.ErrExist) + } + + dst := make([]byte, hex.EncodedLen(len(signer.PrivateKey()))) + hex.Encode(dst, signer.PrivateKey()) + err = os.WriteFile(keyFile, dst, 0o600) + if err != nil { + return fmt.Errorf("failed to write identity file: %w", err) + } + + app.log.With().Info("Created new identity", + log.String("filename", supervisedIDKeyFileName), + signer.PublicKey(), + ) + app.signers = []*signing.EdSigner{signer} + return nil +} + +// LoadIdentities loads all existing identities from the config directory. 
+func (app *App) LoadIdentities() error { + signers := make([]*signing.EdSigner, 0) + pubKeys := make(map[string]*signing.PublicKey) + + dir := filepath.Join(app.Config.DataDir(), keyDir) + err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return fmt.Errorf("failed to walk directory at %s: %w", path, err) + } + + // skip subdirectories and files in them + if d.IsDir() && path != dir { + return fs.SkipDir + } + + // skip files that are not identity files + if filepath.Ext(path) != ".key" { + return nil + } + + // read hex data from file + dst := make([]byte, signing.PrivateKeySize) + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("failed to open identity file at %s: %w", path, err) + } + + n, err := hex.Decode(dst, data) + if err != nil { + return fmt.Errorf("decoding private key in %s: %w", d.Name(), err) + } + if n != signing.PrivateKeySize { + return fmt.Errorf("invalid key size %d/%d for %s", n, signing.PrivateKeySize, d.Name()) + } + + signer, err := signing.NewEdSigner( + signing.WithPrivateKey(dst), + signing.WithPrefix(app.Config.Genesis.GenesisID().Bytes()), + ) + if err != nil { + return fmt.Errorf("failed to construct identity %s: %w", d.Name(), err) + } + + app.log.With().Info("Loaded existing identity", + log.String("filename", d.Name()), + signer.PublicKey(), + ) + signers = append(signers, signer) + pubKeys[d.Name()] = signer.PublicKey() + return nil + }) + if err != nil { + return err + } + if len(signers) == 0 { + return fmt.Errorf("no identity files found: %w", fs.ErrNotExist) + } + + // make sure all public keys are unique + seen := make(map[string]string) + collision := false + for f1, pk := range pubKeys { + if f2, ok := seen[pk.String()]; ok { + app.log.With().Error("duplicate key", + log.String("filename1", f1), + log.String("filename2", f2), + pk, + ) + collision = true + continue + } + seen[pk.String()] = f1 + } + if collision { + return fmt.Errorf("duplicate key found 
in identity files") + } + + app.signers = signers + return nil +} diff --git a/node/node_identities_test.go b/node/node_identities_test.go new file mode 100644 index 0000000000..26583cddee --- /dev/null +++ b/node/node_identities_test.go @@ -0,0 +1,178 @@ +package node + +import ( + "bytes" + "encoding/hex" + "fmt" + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" + + "github.com/spacemeshos/go-spacemesh/log" + "github.com/spacemeshos/go-spacemesh/log/logtest" + "github.com/spacemeshos/go-spacemesh/signing" +) + +func setupAppWithKeys(tb testing.TB, data ...[]byte) (*App, *observer.ObservedLogs) { + observer, observedLogs := observer.New(zapcore.WarnLevel) + logger := zap.New(observer) + app := New(WithLog(log.NewFromLog(logger))) + app.Config.DataDirParent = tb.TempDir() + if len(data) == 0 { + return app, observedLogs + } + + key := data[0] + keyFile := filepath.Join(app.Config.DataDirParent, keyDir, supervisedIDKeyFileName) + require.NoError(tb, os.MkdirAll(filepath.Dir(keyFile), 0o700)) + require.NoError(tb, os.WriteFile(keyFile, key, 0o600)) + + for i, key := range data[1:] { + keyFile = filepath.Join(app.Config.DataDirParent, keyDir, fmt.Sprintf("identity_%d.key", i)) + require.NoError(tb, os.WriteFile(keyFile, key, 0o600)) + } + return app, observedLogs +} + +func TestSpacemeshApp_NewIdentity(t *testing.T) { + t.Run("no key", func(t *testing.T) { + app := New(WithLog(logtest.New(t))) + app.Config.DataDirParent = t.TempDir() + err := app.NewIdentity() + require.NoError(t, err) + require.Len(t, app.signers, 1) + }) + + t.Run("no key but existing directory", func(t *testing.T) { + app := New(WithLog(logtest.New(t))) + app.Config.DataDirParent = t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(app.Config.DataDirParent, keyDir), 0o700)) + err := app.NewIdentity() + require.NoError(t, err) + require.Len(t, app.signers, 1) + }) + + 
t.Run("existing key is not overwritten", func(t *testing.T) { + signer, err := signing.NewEdSigner() + require.NoError(t, err) + before := signer.PublicKey() + + app, _ := setupAppWithKeys(t, []byte(hex.EncodeToString(signer.PrivateKey()))) + err = app.NewIdentity() + require.ErrorContains(t, err, fmt.Sprintf("identity file %s already exists", supervisedIDKeyFileName)) + require.ErrorIs(t, err, fs.ErrExist) + require.Empty(t, app.signers) + + err = app.LoadIdentities() + require.NoError(t, err) + require.NotEmpty(t, app.signers) + after := app.signers[0].PublicKey() + require.Equal(t, before, after) + }) + + t.Run("non default key is preserved on disk", func(t *testing.T) { + signer, err := signing.NewEdSigner() + require.NoError(t, err) + existingKey := signer.PublicKey() + + app, _ := setupAppWithKeys(t, []byte(hex.EncodeToString(signer.PrivateKey()))) + err = os.Rename( + filepath.Join(app.Config.DataDirParent, keyDir, supervisedIDKeyFileName), + filepath.Join(app.Config.DataDirParent, keyDir, "do_not_delete.key"), + ) + require.NoError(t, err) + + err = app.NewIdentity() + require.NoError(t, err) + require.Len(t, app.signers, 1) + newKey := app.signers[0].PublicKey() + require.NotEqual(t, existingKey, newKey) // new key was created and loaded + + err = app.LoadIdentities() + require.NoError(t, err) + require.Len(t, app.signers, 2) + require.Equal(t, app.signers[0].PublicKey(), existingKey) + require.Equal(t, app.signers[1].PublicKey(), newKey) + }) +} + +func TestSpacemeshApp_LoadIdentities(t *testing.T) { + t.Run("no key", func(t *testing.T) { + app := New(WithLog(logtest.New(t))) + app.Config.DataDirParent = t.TempDir() + err := app.LoadIdentities() + require.ErrorIs(t, err, fs.ErrNotExist) + require.Empty(t, app.signers) + }) + + t.Run("no key but existing directory", func(t *testing.T) { + app := New(WithLog(logtest.New(t))) + app.Config.DataDirParent = t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(app.Config.DataDirParent, keyDir), 0o700)) + 
err := app.LoadIdentities() + require.ErrorIs(t, err, fs.ErrNotExist) + require.Empty(t, app.signers) + }) + + t.Run("good key", func(t *testing.T) { + signer, err := signing.NewEdSigner() + require.NoError(t, err) + + app, _ := setupAppWithKeys(t, []byte(hex.EncodeToString(signer.PrivateKey()))) + err = app.LoadIdentities() + require.NoError(t, err) + require.NotEmpty(t, app.signers) + before := app.signers[0].PublicKey() + + err = app.LoadIdentities() + require.NoError(t, err) + require.NotEmpty(t, app.signers) + after := app.signers[0].PublicKey() + require.Equal(t, before, after) + }) + + t.Run("bad length", func(t *testing.T) { + app, _ := setupAppWithKeys(t, bytes.Repeat([]byte("ab"), signing.PrivateKeySize-1)) + err := app.LoadIdentities() + require.ErrorContains(t, err, fmt.Sprintf("invalid key size 63/64 for %s", supervisedIDKeyFileName)) + require.Nil(t, app.signers) + }) + + t.Run("bad hex", func(t *testing.T) { + app, _ := setupAppWithKeys(t, bytes.Repeat([]byte("CV"), signing.PrivateKeySize)) + err := app.LoadIdentities() + require.ErrorContains(t, err, fmt.Sprintf("decoding private key in %s:", supervisedIDKeyFileName)) + require.ErrorIs(t, err, hex.InvalidByteError(byte('V'))) + require.Nil(t, app.signers) + }) + + t.Run("duplicate keys", func(t *testing.T) { + key1, err := signing.NewEdSigner() + require.NoError(t, err) + key2, err := signing.NewEdSigner() + require.NoError(t, err) + + app, observedLogs := setupAppWithKeys(t, + []byte(hex.EncodeToString(key1.PrivateKey())), + []byte(hex.EncodeToString(key2.PrivateKey())), + []byte(hex.EncodeToString(key1.PrivateKey())), + []byte(hex.EncodeToString(key2.PrivateKey())), + ) + err = app.LoadIdentities() + require.ErrorContains(t, err, "duplicate key") + + require.Len(t, observedLogs.All(), 2) + log1 := observedLogs.FilterField(zap.String("public_key", key1.PublicKey().ShortString())) + require.Len(t, log1.All(), 1) + require.Contains(t, log1.All()[0].Message, "duplicate key") + log2 := 
observedLogs.FilterField(zap.String("public_key", key2.PublicKey().ShortString())) + require.Len(t, log2.All(), 1) + require.Contains(t, log2.All()[0].Message, "duplicate key") + }) +} diff --git a/node/node_test.go b/node/node_test.go index fb307bfe23..325d9a9ca8 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -3,7 +3,6 @@ package node import ( "bytes" "context" - "crypto/ed25519" "encoding/hex" "fmt" "io" @@ -19,6 +18,7 @@ import ( mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" pb "github.com/spacemeshos/api/release/go/spacemesh/v1" "github.com/spacemeshos/post/initialization" + "github.com/spacemeshos/post/shared" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/stretchr/testify/assert" @@ -59,82 +59,6 @@ func TestMain(m *testing.M) { os.Exit(res) } -func TestSpacemeshApp_getEdIdentity(t *testing.T) { - r := require.New(t) - - tempdir := t.TempDir() - - // setup spacemesh app - app := New(WithLog(logtest.New(t))) - app.Config.SMESHING.Opts.DataDir = tempdir - app.log = logtest.New(t) - - // Create new identity. - signer1, err := app.LoadOrCreateEdSigner() - r.NoError(err) - infos, err := os.ReadDir(tempdir) - r.NoError(err) - r.Len(infos, 1) - - // Load existing identity. - signer2, err := app.LoadOrCreateEdSigner() - r.NoError(err) - infos, err = os.ReadDir(tempdir) - r.NoError(err) - r.Len(infos, 1) - r.Equal(signer1.PublicKey(), signer2.PublicKey()) - - // Invalidate the identity by changing its file name. - filename := filepath.Join(tempdir, infos[0].Name()) - err = os.Rename(filename, filename+"_") - r.NoError(err) - - // Create new identity. 
- signer3, err := app.LoadOrCreateEdSigner() - r.NoError(err) - infos, err = os.ReadDir(tempdir) - r.NoError(err) - r.Len(infos, 2) - r.NotEqual(signer1.PublicKey(), signer3.PublicKey()) - - t.Run("bad length", func(t *testing.T) { - testLoadOrCreateEdSigner(t, - bytes.Repeat([]byte("ab"), signing.PrivateKeySize-1), - "invalid key size 63/64", - ) - }) - t.Run("bad hex", func(t *testing.T) { - testLoadOrCreateEdSigner(t, - bytes.Repeat([]byte("CV"), signing.PrivateKeySize), - "decoding private key: encoding/hex: invalid byte", - ) - }) - t.Run("good key", func(t *testing.T) { - _, priv, err := ed25519.GenerateKey(nil) - require.NoError(t, err) - testLoadOrCreateEdSigner(t, - []byte(hex.EncodeToString(priv)), - "", - ) - }) -} - -func testLoadOrCreateEdSigner(t *testing.T, data []byte, expect string) { - tempdir := t.TempDir() - app := New(WithLog(logtest.New(t))) - app.Config.SMESHING.Opts.DataDir = tempdir - keyfile := filepath.Join(app.Config.SMESHING.Opts.DataDir, edKeyFileName) - require.NoError(t, os.WriteFile(keyfile, data, 0o600)) - signer, err := app.LoadOrCreateEdSigner() - if len(expect) > 0 { - require.ErrorContains(t, err, expect) - require.Nil(t, signer) - } else { - require.NoError(t, err) - require.NotEmpty(t, signer) - } -} - func newLogger(buf *bytes.Buffer) log.Log { lvl := zap.NewAtomicLevelAt(zapcore.InfoLevel) syncer := zapcore.AddSync(buf) @@ -148,22 +72,23 @@ func newLogger(buf *bytes.Buffer) log.Log { func TestSpacemeshApp_SetLoggers(t *testing.T) { r := require.New(t) - var buf1, buf2 bytes.Buffer + cfg := getTestDefaultConfig(t) + app := New(WithConfig(cfg), WithLog(logtest.New(t))) - app := New(WithLog(logtest.New(t))) - mylogger := "anton" + var buf1, buf2 bytes.Buffer + myLogger := "anton" myLog := newLogger(&buf1) myLog2 := newLogger(&buf2) - app.log = app.addLogger(mylogger, myLog) + app.log = app.addLogger(myLogger, myLog) msg1 := "hi there" app.log.Info(msg1) r.Equal( - fmt.Sprintf("INFO\t%s\t%s\n", mylogger, msg1), + 
fmt.Sprintf("INFO\t%s\t%s\n", myLogger, msg1), buf1.String(), ) - r.NoError(app.SetLogLevel(mylogger, "warn")) - r.Equal("warn", app.loggers[mylogger].String()) + r.NoError(app.SetLogLevel(myLogger, "warn")) + r.Equal("warn", app.loggers[myLogger].String()) buf1.Reset() msg1 = "other logger" @@ -175,19 +100,19 @@ func TestSpacemeshApp_SetLoggers(t *testing.T) { // This one should be printed app.log.Warning(msg3) r.Equal( - fmt.Sprintf("WARN\t%s\t%s\n", mylogger, msg3), + fmt.Sprintf("WARN\t%s\t%s\n", myLogger, msg3), buf1.String(), ) r.Equal(fmt.Sprintf("INFO\t%s\n", msg1), buf2.String()) buf1.Reset() - r.NoError(app.SetLogLevel(mylogger, "info")) + r.NoError(app.SetLogLevel(myLogger, "info")) - msg4 := "nihao" + msg4 := "你好" app.log.Info(msg4) - r.Equal("info", app.loggers[mylogger].String()) + r.Equal("info", app.loggers[myLogger].String()) r.Equal( - fmt.Sprintf("INFO\t%s\t%s\n", mylogger, msg4), + fmt.Sprintf("INFO\t%s\t%s\n", myLogger, msg4), buf1.String(), ) @@ -195,24 +120,26 @@ func TestSpacemeshApp_SetLoggers(t *testing.T) { r.Error(app.SetLogLevel("anton3", "warn")) // test bad loglevel - r.Error(app.SetLogLevel(mylogger, "lulu")) - r.Equal("info", app.loggers[mylogger].String()) + r.Error(app.SetLogLevel(myLogger, "lulu")) + r.Equal("info", app.loggers[myLogger].String()) } func TestSpacemeshApp_AddLogger(t *testing.T) { r := require.New(t) var buf bytes.Buffer - lg := newLogger(&buf) - app := New(WithLog(logtest.New(t))) - mylogger := "anton" - subLogger := app.addLogger(mylogger, lg) + + cfg := getTestDefaultConfig(t) + app := New(WithConfig(cfg), WithLog(logtest.New(t))) + + myLogger := "anton" + subLogger := app.addLogger(myLogger, lg) subLogger.Debug("should not get printed") teststr := "should get printed" subLogger.Info(teststr) r.Equal( - fmt.Sprintf("INFO\t%s\t%s\n", mylogger, teststr), + fmt.Sprintf("INFO\t%s\t%s\n", myLogger, teststr), buf.String(), ) } @@ -238,7 +165,9 @@ func cmdWithRun(run func(*cobra.Command, []string) error) 
*cobra.Command { func TestSpacemeshApp_Cmd(t *testing.T) { r := require.New(t) - app := New(WithLog(logtest.New(t))) + + cfg := getTestDefaultConfig(t) + app := New(WithConfig(cfg), WithLog(logtest.New(t))) expected := `unknown command "illegal" for "node"` expected2 := "Error: " + expected + "\nRun 'node --help' for usage.\n" @@ -279,18 +208,20 @@ func TestSpacemeshApp_GrpcService(t *testing.T) { listener := "127.0.0.1:1242" r := require.New(t) - app := New(WithLog(logtest.New(t))) + cfg := getTestDefaultConfig(t) + cfg.API.PublicListener = listener + app := New(WithConfig(cfg), WithLog(logtest.New(t))) + err := app.NewIdentity() + require.NoError(t, err) run := func(c *cobra.Command, args []string) error { - app.Config.API.PublicListener = listener - app.Config.DataDirParent = t.TempDir() return app.startAPIServices(context.Background()) } defer app.stopServices(context.Background()) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - _, err := grpc.DialContext( + _, err = grpc.DialContext( ctx, listener, grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -330,11 +261,13 @@ func TestSpacemeshApp_GrpcService(t *testing.T) { func TestSpacemeshApp_JsonServiceNotRunning(t *testing.T) { r := require.New(t) - app := New(WithLog(logtest.New(t))) + cfg := getTestDefaultConfig(t) + app := New(WithConfig(cfg), WithLog(logtest.New(t))) + err := app.NewIdentity() + require.NoError(t, err) // Make sure the service is not running by default run := func(c *cobra.Command, args []string) error { - app.Config.DataDirParent = t.TempDir() return app.startAPIServices(context.Background()) } @@ -348,16 +281,18 @@ func TestSpacemeshApp_JsonServiceNotRunning(t *testing.T) { func TestSpacemeshApp_JsonService(t *testing.T) { r := require.New(t) - app := New(WithLog(logtest.New(t))) - const message = "nihao shijie" + + const message = "你好世界" payload := marshalProto(t, &pb.EchoRequest{Msg: &pb.SimpleString{Value: message}}) listener := 
"127.0.0.1:0" + cfg := getTestDefaultConfig(t) + cfg.API.JSONListener = listener + cfg.API.PrivateServices = nil + app := New(WithConfig(cfg), WithLog(logtest.New(t))) + // Make sure the service is not running by default run := func(c *cobra.Command, args []string) error { - app.Config.API.PrivateServices = nil - app.Config.API.JSONListener = listener - app.Config.DataDirParent = t.TempDir() return app.startAPIServices(context.Background()) } @@ -394,16 +329,12 @@ func TestSpacemeshApp_NodeService(t *testing.T) { events.EventHook(), ) // errlog is used to simulate errors in the app - app := New(WithLog(logger)) - app.Config = getTestDefaultConfig(t) - types.SetNetworkHRP(app.Config.NetworkHRP) // ensure that the correct HRP is set when generating the address below - app.Config.SMESHING.CoinbaseAccount = types.GenerateAddress([]byte{1}).String() - app.Config.SMESHING.Opts.DataDir = t.TempDir() - app.Config.SMESHING.Opts.Scrypt.N = 2 + cfg := getTestDefaultConfig(t) + app := New(WithConfig(cfg), WithLog(logger)) - edSgn, err := signing.NewEdSigner() + signer, err := signing.NewEdSigner() require.NoError(t, err) - app.edSgn = edSgn + app.signers = []*signing.EdSigner{signer} mesh, err := mocknet.WithNPeers(1) require.NoError(t, err) @@ -411,8 +342,8 @@ func TestSpacemeshApp_NodeService(t *testing.T) { require.NoError(t, err) app.host = h - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + appCtx, appCancel := context.WithCancel(context.Background()) + defer appCancel() run := func(c *cobra.Command, args []string) error { // Give the error channel a buffer @@ -428,9 +359,12 @@ func TestSpacemeshApp_NodeService(t *testing.T) { // This will block. We need to run the full app here to make sure that // the various services are reporting events correctly. This could probably // be done more surgically, and we don't need _all_ of the services. 
- return app.Start(context.Background()) + return app.Start(appCtx) } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + // Run the app in a goroutine. As noted above, it blocks if it succeeds. // If there's an error in the args, it will return immediately. var eg errgroup.Group @@ -441,7 +375,7 @@ func TestSpacemeshApp_NodeService(t *testing.T) { return nil }) - <-app.started + <-app.Started() conn, err := grpc.DialContext( ctx, app.grpcPublicServer.BoundAddress, @@ -505,6 +439,7 @@ func TestSpacemeshApp_NodeService(t *testing.T) { // Cleanup stops all services and thereby the app <-app.Started() // prevents races when app is not started yet + appCancel() app.Cleanup(context.Background()) // Wait for everything to stop cleanly before ending test @@ -513,22 +448,22 @@ func TestSpacemeshApp_NodeService(t *testing.T) { // E2E app test of the transaction service. func TestSpacemeshApp_TransactionService(t *testing.T) { - r := require.New(t) - listener := "127.0.0.1:14236" - app := New(WithLog(logtest.New(t))) cfg := config.DefaultTestConfig() cfg.DataDirParent = t.TempDir() - app.Config = &cfg + app := New(WithConfig(&cfg), WithLog(logtest.New(t))) signer, err := signing.NewEdSigner() - r.NoError(err) - app.edSgn = signer + require.NoError(t, err) + app.signers = []*signing.EdSigner{signer} address := wallet.Address(signer.PublicKey().Bytes()) + appCtx, appCancel := context.WithCancel(context.Background()) + defer appCancel() + run := func(c *cobra.Command, args []string) error { - r.NoError(app.Initialize()) + require.NoError(t, app.Initialize()) // GRPC configuration app.Config.API.PublicListener = listener @@ -558,7 +493,7 @@ func TestSpacemeshApp_TransactionService(t *testing.T) { // This will block. We need to run the full app here to make sure that // the various services are reporting events correctly. This could probably // be done more surgically, and we don't need _all_ of the services. 
- return app.Start(context.Background()) + return app.Start(appCtx) } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -570,8 +505,8 @@ func TestSpacemeshApp_TransactionService(t *testing.T) { wg.Add(1) go func() { str, err := testArgs(ctx, cmdWithRun(run)) - r.Empty(str) - r.NoError(err) + require.Empty(t, str) + require.NoError(t, err) wg.Done() }() @@ -591,8 +526,8 @@ func TestSpacemeshApp_TransactionService(t *testing.T) { grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), ) - r.NoError(err) - t.Cleanup(func() { r.NoError(conn.Close()) }) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, conn.Close()) }) c := pb.NewTransactionServiceClient(conn) tx1 := types.NewRawTx( @@ -639,6 +574,7 @@ func TestSpacemeshApp_TransactionService(t *testing.T) { // This stops the app // Cleanup stops all services and thereby the app + appCancel() app.Cleanup(context.Background()) // Wait for it to stop @@ -917,9 +853,8 @@ func TestHRP(t *testing.T) { func TestGenesisConfig(t *testing.T) { t.Run("config is written to a file", func(t *testing.T) { - app := New() - app.Config = getTestDefaultConfig(t) - app.Config.DataDirParent = t.TempDir() + cfg := getTestDefaultConfig(t) + app := New(WithConfig(cfg)) require.NoError(t, app.Initialize()) t.Cleanup(func() { app.Cleanup(context.Background()) }) @@ -933,9 +868,8 @@ func TestGenesisConfig(t *testing.T) { }) t.Run("no error if no diff", func(t *testing.T) { - app := New() - app.Config = getTestDefaultConfig(t) - app.Config.DataDirParent = t.TempDir() + cfg := getTestDefaultConfig(t) + app := New(WithConfig(cfg)) require.NoError(t, app.Initialize()) app.Cleanup(context.Background()) @@ -945,9 +879,8 @@ func TestGenesisConfig(t *testing.T) { }) t.Run("fatal error on a diff", func(t *testing.T) { - app := New() - app.Config = getTestDefaultConfig(t) - app.Config.DataDirParent = t.TempDir() + cfg := getTestDefaultConfig(t) + app := New(WithConfig(cfg)) require.NoError(t, 
app.Initialize()) t.Cleanup(func() { app.Cleanup(context.Background()) }) @@ -959,18 +892,17 @@ func TestGenesisConfig(t *testing.T) { }) t.Run("not valid time", func(t *testing.T) { - app := New() - app.Config = getTestDefaultConfig(t) - app.Config.DataDirParent = t.TempDir() - app.Config.Genesis.GenesisTime = time.Now().Format(time.RFC1123) + cfg := getTestDefaultConfig(t) + cfg.Genesis.GenesisTime = time.Now().Format(time.RFC1123) + app := New(WithConfig(cfg)) require.ErrorContains(t, app.Initialize(), "time.RFC3339") }) + t.Run("long extra data", func(t *testing.T) { - app := New() - app.Config = getTestDefaultConfig(t) - app.Config.DataDirParent = t.TempDir() - app.Config.Genesis.ExtraData = string(make([]byte, 256)) + cfg := getTestDefaultConfig(t) + cfg.Genesis.ExtraData = string(make([]byte, 256)) + app := New(WithConfig(cfg)) require.ErrorContains(t, app.Initialize(), "extra-data") }) @@ -978,8 +910,8 @@ func TestGenesisConfig(t *testing.T) { func TestFlock(t *testing.T) { t.Run("sanity", func(t *testing.T) { - app := New() - app.Config = getTestDefaultConfig(t) + cfg := getTestDefaultConfig(t) + app := New(WithConfig(cfg)) require.NoError(t, app.Lock()) t.Cleanup(app.Unlock) @@ -991,9 +923,9 @@ func TestFlock(t *testing.T) { }) t.Run("dir doesn't exist", func(t *testing.T) { - app := New() - app.Config = getTestDefaultConfig(t) - app.Config.FileLock = filepath.Join(t.TempDir(), "newdir", "LOCK") + cfg := getTestDefaultConfig(t) + cfg.FileLock = filepath.Join(t.TempDir(), "newdir", "LOCK") + app := New(WithConfig(cfg)) require.NoError(t, app.Lock()) t.Cleanup(app.Unlock) @@ -1008,7 +940,7 @@ func TestAdminEvents(t *testing.T) { require.NoError(t, err) cfg.DataDirParent = t.TempDir() cfg.FileLock = filepath.Join(cfg.DataDirParent, "LOCK") - cfg.SMESHING.Opts.DataDir = cfg.DataDirParent + cfg.SMESHING.Opts.DataDir = t.TempDir() cfg.SMESHING.Opts.Scrypt.N = 2 cfg.POSTService.PostServiceCmd = activation.DefaultTestPostServiceConfig().PostServiceCmd @@ -1017,9 
+949,8 @@ func TestAdminEvents(t *testing.T) { logger := logtest.New(t, zapcore.DebugLevel) app := New(WithConfig(&cfg), WithLog(logger)) - signer, err := app.LoadOrCreateEdSigner() - require.NoError(t, err) - app.edSgn = signer // https://github.com/spacemeshos/go-spacemesh/issues/4653 + + require.NoError(t, app.NewIdentity()) require.NoError(t, app.Initialize()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -1029,7 +960,6 @@ func TestAdminEvents(t *testing.T) { return err } app.Cleanup(context.Background()) - app.eg.Wait() // https://github.com/spacemeshos/go-spacemesh/issues/4653 return nil }) t.Cleanup(func() { assert.NoError(t, eg.Wait()) }) @@ -1076,33 +1006,63 @@ func TestAdminEvents(t *testing.T) { } } -func TestAdminEvents_UnspecifiedAddresses(t *testing.T) { +func launchPostSupervisor( + tb testing.TB, + log *zap.Logger, + mgr *activation.PostSetupManager, + id types.NodeID, + address string, + postCfg activation.PostConfig, + postOpts activation.PostSetupOpts, +) func() { + cmdCfg := activation.DefaultTestPostServiceConfig() + cmdCfg.NodeAddress = fmt.Sprintf("http://%s", address) + provingOpts := activation.DefaultPostProvingOpts() + provingOpts.RandomXMode = activation.PostRandomXModeLight + + ps, err := activation.NewPostSupervisor(log, cmdCfg, postCfg, provingOpts, mgr) + require.NoError(tb, err) + require.NotNil(tb, ps) + require.NoError(tb, ps.Start(postOpts, id)) + return func() { assert.NoError(tb, ps.Stop(false)) } +} + +func TestAdminEvents_MultiSmesher(t *testing.T) { if testing.Short() { t.Skip() } cfg, err := presets.Get("standalone") require.NoError(t, err) - cfg.DataDirParent = t.TempDir() cfg.FileLock = filepath.Join(cfg.DataDirParent, "LOCK") - cfg.SMESHING.Opts.DataDir = cfg.DataDirParent cfg.SMESHING.Opts.Scrypt.N = 2 + cfg.SMESHING.Start = false cfg.POSTService.PostServiceCmd = activation.DefaultTestPostServiceConfig().PostServiceCmd - // Expose APIs on all interfaces - cfg.API.PublicListener = 
"0.0.0.0:10092" - cfg.API.PrivateListener = "0.0.0.0:10093" - cfg.API.PostListener = "0.0.0.0:10094" - cfg.Genesis.GenesisTime = time.Now().Add(5 * time.Second).Format(time.RFC3339) types.SetLayersPerEpoch(cfg.LayersPerEpoch) - logger := logtest.New(t, zapcore.DebugLevel) + logger := logtest.New(t) app := New(WithConfig(&cfg), WithLog(logger)) - signer, err := app.LoadOrCreateEdSigner() - require.NoError(t, err) - app.edSgn = signer // https://github.com/spacemeshos/go-spacemesh/issues/4653 + + dir := filepath.Join(app.Config.DataDir(), keyDir) + require.NoError(t, os.MkdirAll(dir, 0o700)) + for i := 0; i < 2; i++ { + signer, err := signing.NewEdSigner( + signing.WithPrefix(app.Config.Genesis.GenesisID().Bytes()), + ) + require.NoError(t, err) + + keyFile := filepath.Join(dir, fmt.Sprintf("node_%d.key", i)) + dst := make([]byte, hex.EncodedLen(len(signer.PrivateKey()))) + hex.Encode(dst, signer.PrivateKey()) + require.NoError(t, os.WriteFile(keyFile, dst, 0o600)) + } + + require.NoError(t, app.LoadIdentities()) + require.Len(t, app.signers, 2) require.NoError(t, app.Initialize()) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() var eg errgroup.Group @@ -1111,11 +1071,33 @@ func TestAdminEvents_UnspecifiedAddresses(t *testing.T) { return err } app.Cleanup(context.Background()) - app.eg.Wait() // https://github.com/spacemeshos/go-spacemesh/issues/4653 return nil }) t.Cleanup(func() { assert.NoError(t, eg.Wait()) }) + <-app.Started() + for _, signer := range app.signers { + mgr, err := activation.NewPostSetupManager( + cfg.POST, + logger.Zap(), + app.cachedDB, + types.ATXID(app.Config.Genesis.GoldenATX()), + app.syncer, + app.validator, + ) + require.NoError(t, err) + + cfg.SMESHING.Opts.DataDir = t.TempDir() + t.Cleanup(launchPostSupervisor(t, + logger.Zap(), + mgr, + signer.NodeID(), + cfg.API.PostListener, + cfg.POST, + cfg.SMESHING.Opts, + )) + } + grpcCtx, cancel := context.WithTimeout(ctx, 20*time.Second) defer cancel() conn, err := 
grpc.DialContext( @@ -1136,23 +1118,108 @@ func TestAdminEvents_UnspecifiedAddresses(t *testing.T) { for i := 0; i < 4; i++ { stream, err := client.EventsStream(tctx, &pb.EventStreamRequest{}) require.NoError(t, err) - success := []pb.IsEventDetails{ - &pb.Event_Beacon{}, - &pb.Event_InitStart{}, - &pb.Event_InitComplete{}, - &pb.Event_PostServiceStarted{}, - &pb.Event_PostStart{}, - &pb.Event_PostComplete{}, - &pb.Event_PoetWaitRound{}, - &pb.Event_PoetWaitProof{}, - &pb.Event_PostStart{}, - &pb.Event_PostComplete{}, - &pb.Event_AtxPublished{}, + + matchers := map[int]func(pb.IsEventDetails) bool{ + 0: func(ev pb.IsEventDetails) bool { + _, ok := ev.(*pb.Event_Beacon) + return ok + }, + 1: func(ev pb.IsEventDetails) bool { + startEv, ok := ev.(*pb.Event_PostStart) + if !ok { + return false + } + return bytes.Equal(startEv.PostStart.Smesher, app.signers[0].NodeID().Bytes()) && + bytes.Equal(startEv.PostStart.Challenge, shared.ZeroChallenge) + }, + 2: func(ev pb.IsEventDetails) bool { + completeEv, ok := ev.(*pb.Event_PostComplete) + if !ok { + return false + } + return bytes.Equal(completeEv.PostComplete.Smesher, app.signers[0].NodeID().Bytes()) && + bytes.Equal(completeEv.PostComplete.Challenge, shared.ZeroChallenge) + }, + 3: func(ev pb.IsEventDetails) bool { + startEv, ok := ev.(*pb.Event_PostStart) + if !ok { + return false + } + return bytes.Equal(startEv.PostStart.Smesher, app.signers[1].NodeID().Bytes()) && + bytes.Equal(startEv.PostStart.Challenge, shared.ZeroChallenge) + }, + 4: func(ev pb.IsEventDetails) bool { + completeEv, ok := ev.(*pb.Event_PostComplete) + if !ok { + return false + } + return bytes.Equal(completeEv.PostComplete.Smesher, app.signers[1].NodeID().Bytes()) && + bytes.Equal(completeEv.PostComplete.Challenge, shared.ZeroChallenge) + }, + 5: func(ev pb.IsEventDetails) bool { + // TODO(mafa): this event happens once for each NodeID, but should probably only happen once for all + _, ok := ev.(*pb.Event_PoetWaitRound) + return ok + }, + 6: 
func(ev pb.IsEventDetails) bool { + // TODO(mafa): this event happens once for each NodeID, but should probably only happen once for all + _, ok := ev.(*pb.Event_PoetWaitProof) + return ok + }, + 7: func(ev pb.IsEventDetails) bool { + startEv, ok := ev.(*pb.Event_PostStart) + if !ok { + return false + } + return bytes.Equal(startEv.PostStart.Smesher, app.signers[0].NodeID().Bytes()) && + !bytes.Equal(startEv.PostStart.Challenge, shared.ZeroChallenge) + }, + 8: func(ev pb.IsEventDetails) bool { + completeEv, ok := ev.(*pb.Event_PostComplete) + if !ok { + return false + } + return bytes.Equal(completeEv.PostComplete.Smesher, app.signers[0].NodeID().Bytes()) && + !bytes.Equal(completeEv.PostComplete.Challenge, shared.ZeroChallenge) + }, + 9: func(ev pb.IsEventDetails) bool { + startEv, ok := ev.(*pb.Event_PostStart) + if !ok { + return false + } + return bytes.Equal(startEv.PostStart.Smesher, app.signers[1].NodeID().Bytes()) && + !bytes.Equal(startEv.PostStart.Challenge, shared.ZeroChallenge) + }, + 10: func(ev pb.IsEventDetails) bool { + completeEv, ok := ev.(*pb.Event_PostComplete) + if !ok { + return false + } + return bytes.Equal(completeEv.PostComplete.Smesher, app.signers[1].NodeID().Bytes()) && + !bytes.Equal(completeEv.PostComplete.Challenge, shared.ZeroChallenge) + }, + 11: func(ev pb.IsEventDetails) bool { + _, ok := ev.(*pb.Event_AtxPublished) + return ok + }, + 12: func(ev pb.IsEventDetails) bool { + _, ok := ev.(*pb.Event_AtxPublished) + return ok + }, } - for idx, ev := range success { + for { msg, err := stream.Recv() require.NoError(t, err, "stream %d", i) - require.IsType(t, ev, msg.Details, "stream %d, event %d", i, idx) + for idx, matcher := range matchers { + if matcher(msg.Details) { + t.Log("matched event", idx) + delete(matchers, idx) + break + } + } + if len(matchers) == 0 { + break + } } require.NoError(t, stream.CloseSend()) } @@ -1167,6 +1234,13 @@ func TestEmptyExtraData(t *testing.T) { func getTestDefaultConfig(tb testing.TB) 
*config.Config { cfg := config.MainnetConfig() + types.SetNetworkHRP(cfg.NetworkHRP) + types.SetLayersPerEpoch(cfg.LayersPerEpoch) + + tmp := tb.TempDir() + cfg.DataDirParent = tmp + cfg.FileLock = filepath.Join(tmp, "LOCK") + cfg.LayerDuration = 20 * time.Second // is set to 0 to make sync start immediately when node starts cfg.P2P.MinPeers = 0 @@ -1179,7 +1253,10 @@ func getTestDefaultConfig(tb testing.TB) *config.Config { cfg.SMESHING = config.DefaultSmeshingConfig() cfg.SMESHING.Start = true + cfg.SMESHING.CoinbaseAccount = types.GenerateAddress([]byte{1}).String() + cfg.SMESHING.Opts.DataDir = filepath.Join(tmp, "post") cfg.SMESHING.Opts.NumUnits = cfg.POST.MinNumUnits + 1 + cfg.SMESHING.Opts.Scrypt.N = 2 cfg.SMESHING.Opts.ProviderID.SetUint32(initialization.CPUProviderID()) cfg.HARE3.RoundDuration = 2 @@ -1191,12 +1268,8 @@ func getTestDefaultConfig(tb testing.TB) *config.Config { cfg.Tortoise.Hdist = 5 cfg.Tortoise.Zdist = 5 - cfg.LayerDuration = 20 * time.Second cfg.HareEligibility.ConfidenceParam = 1 cfg.Sync.Interval = 2 * time.Second - tmp := tb.TempDir() - cfg.DataDirParent = tmp - cfg.FileLock = filepath.Join(tmp, "LOCK") cfg.FETCH.RequestTimeout = 10 * time.Second cfg.FETCH.RequestHardTimeout = 20 * time.Second @@ -1204,12 +1277,8 @@ func getTestDefaultConfig(tb testing.TB) *config.Config { cfg.FETCH.BatchTimeout = 5 * time.Second cfg.Beacon = beacon.NodeSimUnitTestConfig() - cfg.Genesis = config.DefaultTestGenesisConfig() - - cfg.POSTService = config.DefaultTestConfig().POSTService - - types.SetLayersPerEpoch(cfg.LayersPerEpoch) + cfg.POSTService = activation.DefaultTestPostServiceConfig() return &cfg } diff --git a/node/test_network.go b/node/test_network.go index 352e048400..39f65d07bf 100644 --- a/node/test_network.go +++ b/node/test_network.go @@ -88,7 +88,6 @@ func NewTestNetwork(t *testing.T, conf config.Config, l log.Log, size int) []*Te defer cancel() for _, a := range apps { a.Cleanup(ctx) - a.eg.Wait() } }) @@ -114,9 +113,8 @@ func NewApp(t 
*testing.T, conf *config.Config, l log.Log) *App { err := app.Initialize() require.NoError(t, err) - /* Create or load miner identity */ - app.edSgn, err = app.LoadOrCreateEdSigner() - require.NoError(t, err, "could not retrieve identity") + err = app.NewIdentity() + require.NoError(t, err, "could not create identity") return app } diff --git a/proposals/util/util.go b/proposals/util/util.go index d0312da77a..71c77dcf74 100644 --- a/proposals/util/util.go +++ b/proposals/util/util.go @@ -30,14 +30,8 @@ func GetNumEligibleSlots(weight, minWeight, totalWeight uint64, committeeSize, l if totalWeight == 0 { return 0, ErrZeroTotalWeight } - numEligible := weight * uint64( - committeeSize, - ) * uint64( - layersPerEpoch, - ) / max( - minWeight, - totalWeight, - ) // TODO: ensure no overflow + // TODO: numEligible could overflow uint64 if weight is very large + numEligible := weight * uint64(committeeSize) * uint64(layersPerEpoch) / max(minWeight, totalWeight) if numEligible == 0 { numEligible = 1 } diff --git a/systest/tests/checkpoint_test.go b/systest/tests/checkpoint_test.go index cd9d6525bd..f718ac7de6 100644 --- a/systest/tests/checkpoint_test.go +++ b/systest/tests/checkpoint_test.go @@ -33,6 +33,7 @@ func reuseCluster(tctx *testcontext.Context, restoreLayer uint32) (*cluster.Clus } func TestCheckpoint(t *testing.T) { + // TODO(mafa): add new test with multi-smeshing nodes t.Parallel() tctx := testcontext.New(t, testcontext.Labels("sanity")) diff --git a/systest/tests/distributed_post_verification_test.go b/systest/tests/distributed_post_verification_test.go index 9805c976a8..1518818609 100644 --- a/systest/tests/distributed_post_verification_test.go +++ b/systest/tests/distributed_post_verification_test.go @@ -117,7 +117,6 @@ func TestPostMalfeasanceProof(t *testing.T) { // 1. 
Initialize postSetupMgr, err := activation.NewPostSetupManager( - signer.NodeID(), cfg.POST, logger.Named("post"), datastore.NewCachedDB(sql.InMemory(), log.NewNop()), @@ -135,7 +134,7 @@ func TestPostMalfeasanceProof(t *testing.T) { postSetupMgr, ) require.NoError(t, err) - require.NoError(t, postSupervisor.Start(cfg.SMESHING.Opts)) + require.NoError(t, postSupervisor.Start(cfg.SMESHING.Opts, signer.NodeID())) t.Cleanup(func() { assert.NoError(t, postSupervisor.Stop(false)) }) // 2. create ATX with invalid POST labels diff --git a/systest/tests/smeshing_test.go b/systest/tests/smeshing_test.go index 4609f1da1d..4d16fb9c3e 100644 --- a/systest/tests/smeshing_test.go +++ b/systest/tests/smeshing_test.go @@ -29,6 +29,7 @@ import ( ) func TestSmeshing(t *testing.T) { + // TODO(mafa): add new test with multi-smeshing nodes t.Parallel() tctx := testcontext.New(t, testcontext.Labels("sanity")) diff --git a/tortoise/algorithm.go b/tortoise/algorithm.go index c09543b83b..0a544cc8ce 100644 --- a/tortoise/algorithm.go +++ b/tortoise/algorithm.go @@ -183,9 +183,9 @@ func (t *Tortoise) OnWeakCoin(lid types.LayerID, coin bool) { } } -// OnMalfeasance registers node id as malfeasent. +// OnMalfeasance registers node id as malfeasant. // - ballots from this id will have zero weight -// - atxs - will not be counted towards global/local threhsolds +// - atxs - will not be counted towards global/local thresholds // If node registers equivocating ballot/atx it should // call OnMalfeasance before storing ballot/atx. func (t *Tortoise) OnMalfeasance(id types.NodeID) {