From 9a76d396d7fbc3c9a3a7b2ca9d14fde8beae2e15 Mon Sep 17 00:00:00 2001 From: Edward Mack Date: Fri, 13 Aug 2021 16:31:35 -0400 Subject: [PATCH] fix(rpc/subscription): subscribe runtime version notify when version changes (#1686) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add code for runtime version changed fix * fix return type * update mock in Makefile * fix core_api mock * implement notifications for runtime updates * lint * implement unregister runtime version listener * make mocks and lint * refactor runtime subscription name, change error handeling * remove unneeded select case * lint, refactor GetID to GetChannelID * lint * add sync.WaitGroup to notify Runtime Updated * add test for Register UnRegister Runtime Update Channel * add check for id * add register panic test * add test to produce panic * generate id as uuid, add locks, delete from map * update runtime updated channel ids to use uint32 * change logging to debug * remove comment * move uuid into check loop * update runtime test * add test for notify runtime updated * update runtime tests * move runtime notify from coreAPI to blockStateAPI * fix mocks for tests * refactor state_test * add tests * lint * remove commented code, fix var assignment * fix merge conflict * add check if channel is ok * update unit test * send notification as go routine Co-authored-by: Eclésio Júnior --- dot/core/service_test.go | 7 +- dot/rpc/modules/api.go | 2 + dot/rpc/modules/api_mocks.go | 18 +++- dot/rpc/modules/api_mocks_test.go | 43 ++++++++ dot/rpc/modules/mocks/BlockAPI.go | 37 +++++++ dot/rpc/subscription/listeners.go | 46 ++++++-- dot/rpc/subscription/listeners_test.go | 59 +++++++++++ dot/rpc/subscription/subscription.go | 2 + dot/rpc/subscription/websocket.go | 31 +++++- dot/rpc/websocket_test.go | 2 +- dot/state/block.go | 48 +++++---- dot/state/block_notify.go | 64 +++++++++++ dot/state/block_notify_test.go | 49 ++++++++- go.mod | 1 + lib/runtime/mocks/version.go | 140 +++++++++++++++++++++++++ 15 files changed, 513 insertions(+), 36 deletions(-) create mode 100644 dot/rpc/modules/api_mocks_test.go create mode 100644 lib/runtime/mocks/version.go diff --git a/dot/core/service_test.go b/dot/core/service_test.go index 9d7a13ea63..e5ef9188d8 100644 --- a/dot/core/service_test.go +++ b/dot/core/service_test.go @@ -25,6 +25,8 @@ import ( "testing" "time" + "github.com/ChainSafe/gossamer/dot/core/mocks" + coremocks "github.com/ChainSafe/gossamer/dot/core/mocks" "github.com/ChainSafe/gossamer/dot/network" "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/sync" @@ -33,6 +35,7 @@ import ( "github.com/ChainSafe/gossamer/lib/keystore" "github.com/ChainSafe/gossamer/lib/runtime" "github.com/ChainSafe/gossamer/lib/runtime/extrinsic" + runtimemocks "github.com/ChainSafe/gossamer/lib/runtime/mocks" "github.com/ChainSafe/gossamer/lib/runtime/storage" "github.com/ChainSafe/gossamer/lib/runtime/wasmer" "github.com/ChainSafe/gossamer/lib/transaction" @@ -41,10 +44,6 @@ import ( log "github.com/ChainSafe/log15" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - - "github.com/ChainSafe/gossamer/dot/core/mocks" - coremocks "github.com/ChainSafe/gossamer/dot/core/mocks" - runtimemocks "github.com/ChainSafe/gossamer/lib/runtime/mocks" ) func addTestBlocksToState(t *testing.T, depth int, blockState BlockState) { diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index 023b65995e..1184b13251 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -40,6 +40,8 @@ type BlockAPI interface { RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) UnregisterFinalisedChannel(id byte) SubChain(start, end common.Hash) ([]common.Hash, error) + RegisterRuntimeUpdatedChannel(ch chan<- runtime.Version) (uint32, error) + UnregisterRuntimeUpdatedChannel(id uint32) bool } // NetworkAPI interface for network state methods diff --git a/dot/rpc/modules/api_mocks.go b/dot/rpc/modules/api_mocks.go index b7ebcb21df..fa1d6c732a 100644 --- a/dot/rpc/modules/api_mocks.go +++ b/dot/rpc/modules/api_mocks.go @@ -3,6 +3,7 @@ package modules import ( modulesmocks "github.com/ChainSafe/gossamer/dot/rpc/modules/mocks" "github.com/ChainSafe/gossamer/lib/common" + runtimemocks "github.com/ChainSafe/gossamer/lib/runtime/mocks" "github.com/stretchr/testify/mock" ) @@ -35,6 +36,8 @@ func NewMockBlockAPI() *modulesmocks.BlockAPI { m.On("GetJustification", mock.AnythingOfType("common.Hash")).Return(make([]byte, 10), nil) m.On("HasJustification", mock.AnythingOfType("common.Hash")).Return(true, nil) m.On("SubChain", mock.AnythingOfType("common.Hash"), mock.AnythingOfType("common.Hash")).Return(make([]common.Hash, 0), nil) + m.On("RegisterRuntimeUpdatedChannel", mock.AnythingOfType("chan<- runtime.Version")).Return(uint32(0), nil) + return m } @@ -43,9 +46,22 @@ func NewMockCoreAPI() *modulesmocks.MockCoreAPI { m := new(modulesmocks.MockCoreAPI) m.On("InsertKey", mock.AnythingOfType("crypto.Keypair")) m.On("HasKey", mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return(false, nil) - m.On("GetRuntimeVersion", mock.AnythingOfType("*common.Hash")).Return(nil, nil) + m.On("GetRuntimeVersion", mock.AnythingOfType("*common.Hash")).Return(NewMockVersion(), nil) m.On("IsBlockProducer").Return(false) m.On("HandleSubmittedExtrinsic", mock.AnythingOfType("types.Extrinsic")).Return(nil) m.On("GetMetadata", mock.AnythingOfType("*common.Hash")).Return(nil, nil) return m } + +// NewMockVersion creates and returns an runtime Version interface mock +func NewMockVersion() *runtimemocks.MockVersion { + m := new(runtimemocks.MockVersion) + m.On("SpecName").Return([]byte(`mock-spec`)) + m.On("ImplName").Return(nil) + m.On("AuthoringVersion").Return(uint32(0)) + m.On("SpecVersion").Return(uint32(0)) + m.On("ImplVersion").Return(uint32(0)) + m.On("TransactionVersion").Return(uint32(0)) + m.On("APIItems").Return(nil) + return m +} diff --git a/dot/rpc/modules/api_mocks_test.go b/dot/rpc/modules/api_mocks_test.go new file mode 100644 index 0000000000..1cf5759c42 --- /dev/null +++ b/dot/rpc/modules/api_mocks_test.go @@ -0,0 +1,43 @@ +// Copyright 2019 ChainSafe Systems (ON) Corp. +// This file is part of gossamer. +// +// The gossamer library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The gossamer library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the gossamer library. If not, see . + +package modules + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMockStorageAPI(t *testing.T) { + m := NewMockStorageAPI() + require.NotNil(t, m) +} + +func TestNewMockBlockAPI(t *testing.T) { + m := NewMockBlockAPI() + require.NotNil(t, m) +} + +func TestNewMockCoreAPI(t *testing.T) { + m := NewMockCoreAPI() + require.NotNil(t, m) +} + +func TestNewMockVersion(t *testing.T) { + m := NewMockVersion() + require.NotNil(t, m) +} diff --git a/dot/rpc/modules/mocks/BlockAPI.go b/dot/rpc/modules/mocks/BlockAPI.go index 982c91ccc7..dd6e6db19b 100644 --- a/dot/rpc/modules/mocks/BlockAPI.go +++ b/dot/rpc/modules/mocks/BlockAPI.go @@ -8,6 +8,8 @@ import ( common "github.com/ChainSafe/gossamer/lib/common" mock "github.com/stretchr/testify/mock" + runtime "github.com/ChainSafe/gossamer/lib/runtime" + types "github.com/ChainSafe/gossamer/dot/types" ) @@ -233,6 +235,27 @@ func (_m *BlockAPI) RegisterImportedChannel(ch chan<- *types.Block) (byte, error return r0, r1 } +// RegisterRuntimeUpdatedChannel provides a mock function with given fields: ch +func (_m *BlockAPI) RegisterRuntimeUpdatedChannel(ch chan<- runtime.Version) (uint32, error) { + ret := _m.Called(ch) + + var r0 uint32 + if rf, ok := ret.Get(0).(func(chan<- runtime.Version) uint32); ok { + r0 = rf(ch) + } else { + r0 = ret.Get(0).(uint32) + } + + var r1 error + if rf, ok := ret.Get(1).(func(chan<- runtime.Version) error); ok { + r1 = rf(ch) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // SubChain provides a mock function with given fields: start, end func (_m *BlockAPI) SubChain(start common.Hash, end common.Hash) ([]common.Hash, error) { ret := _m.Called(start, end) @@ -265,3 +288,17 @@ func (_m *BlockAPI) UnregisterFinalisedChannel(id byte) { func (_m *BlockAPI) UnregisterImportedChannel(id byte) { _m.Called(id) } + +// UnregisterRuntimeUpdatedChannel provides a mock function with given fields: id +func (_m *BlockAPI) UnregisterRuntimeUpdatedChannel(id uint32) bool { + ret := _m.Called(id) + + var r0 bool + if rf, ok := ret.Get(0).(func(uint32) bool); ok { + r0 = rf(id) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/dot/rpc/subscription/listeners.go b/dot/rpc/subscription/listeners.go index 1d7994927c..3aa9287e88 100644 --- a/dot/rpc/subscription/listeners.go +++ b/dot/rpc/subscription/listeners.go @@ -26,6 +26,7 @@ import ( "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/runtime" ) const ( @@ -282,20 +283,26 @@ func (l *ExtrinsicSubmitListener) Stop() error { // RuntimeVersionListener to handle listening for Runtime Version type RuntimeVersionListener struct { - wsconn *WSConn - subID uint32 + wsconn WSConnAPI + subID uint32 + runtimeUpdate chan runtime.Version + channelID uint32 + coreAPI modules.CoreAPI +} + +// VersionListener interface defining methods that version listener must implement +type VersionListener interface { + GetChannelID() uint32 } // Listen implementation of Listen interface to listen for runtime version changes func (l *RuntimeVersionListener) Listen() { // This sends current runtime version once when subscription is created - // TODO (ed) add logic to send updates when runtime version changes - rtVersion, err := l.wsconn.CoreAPI.GetRuntimeVersion(nil) + rtVersion, err := l.coreAPI.GetRuntimeVersion(nil) if err != nil { return } ver := modules.StateRuntimeVersionResponse{} - ver.SpecName = string(rtVersion.SpecName()) ver.ImplName = string(rtVersion.ImplName()) ver.AuthoringVersion = rtVersion.AuthoringVersion() @@ -304,7 +311,34 @@ func (l *RuntimeVersionListener) Listen() { ver.TransactionVersion = rtVersion.TransactionVersion() ver.Apis = modules.ConvertAPIs(rtVersion.APIItems()) - l.wsconn.safeSend(newSubscriptionResponse(stateRuntimeVersionMethod, l.subID, ver)) + go l.wsconn.safeSend(newSubscriptionResponse(stateRuntimeVersionMethod, l.subID, ver)) + + // listen for runtime updates + go func() { + for { + info, ok := <-l.runtimeUpdate + if !ok { + return + } + + ver := modules.StateRuntimeVersionResponse{} + + ver.SpecName = string(info.SpecName()) + ver.ImplName = string(info.ImplName()) + ver.AuthoringVersion = info.AuthoringVersion() + ver.SpecVersion = info.SpecVersion() + ver.ImplVersion = info.ImplVersion() + ver.TransactionVersion = info.TransactionVersion() + ver.Apis = modules.ConvertAPIs(info.APIItems()) + + l.wsconn.safeSend(newSubscriptionResponse(stateRuntimeVersionMethod, l.subID, ver)) + } + }() +} + +// GetChannelID function that returns listener's channel ID +func (l *RuntimeVersionListener) GetChannelID() uint32 { + return l.channelID } // Stop to runtimeVersionListener not implemented yet because the listener diff --git a/dot/rpc/subscription/listeners_test.go b/dot/rpc/subscription/listeners_test.go index 7be7580550..4dc89479b2 100644 --- a/dot/rpc/subscription/listeners_test.go +++ b/dot/rpc/subscription/listeners_test.go @@ -19,10 +19,12 @@ package subscription import ( "encoding/json" "fmt" + "io/ioutil" "log" "math/big" "net/http" "net/http/httptest" + "path/filepath" "strings" "testing" "time" @@ -33,6 +35,8 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/grandpa" + "github.com/ChainSafe/gossamer/lib/runtime" + "github.com/ChainSafe/gossamer/lib/runtime/wasmer" "github.com/gorilla/websocket" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -325,3 +329,58 @@ func setupWSConn(t *testing.T) (*WSConn, *websocket.Conn, func()) { return wskt, ws, cancel } + +func TestRuntimeChannelListener_Listen(t *testing.T) { + notifyChan := make(chan runtime.Version) + mockConnection := &MockWSConnAPI{} + rvl := RuntimeVersionListener{ + wsconn: mockConnection, + subID: 0, + runtimeUpdate: notifyChan, + coreAPI: modules.NewMockCoreAPI(), + } + + expectedInitialVersion := modules.StateRuntimeVersionResponse{ + SpecName: "mock-spec", + Apis: modules.ConvertAPIs(nil), + } + + expectedInitialResponse := newSubcriptionBaseResponseJSON() + expectedInitialResponse.Method = "state_runtimeVersion" + expectedInitialResponse.Params.Result = expectedInitialVersion + + instance := wasmer.NewTestInstance(t, runtime.NODE_RUNTIME) + _, err := runtime.GetRuntimeBlob(runtime.POLKADOT_RUNTIME_FP, runtime.POLKADOT_RUNTIME_URL) + require.NoError(t, err) + fp, err := filepath.Abs(runtime.POLKADOT_RUNTIME_FP) + require.NoError(t, err) + code, err := ioutil.ReadFile(fp) + require.NoError(t, err) + version, err := instance.CheckRuntimeVersion(code) + require.NoError(t, err) + + expectedUpdatedVersion := modules.StateRuntimeVersionResponse{ + SpecName: "polkadot", + ImplName: "parity-polkadot", + AuthoringVersion: 0, + SpecVersion: 25, + ImplVersion: 0, + TransactionVersion: 5, + Apis: modules.ConvertAPIs(version.APIItems()), + } + + expectedUpdateResponse := newSubcriptionBaseResponseJSON() + expectedUpdateResponse.Method = "state_runtimeVersion" + expectedUpdateResponse.Params.Result = expectedUpdatedVersion + + go rvl.Listen() + + //check initial response + time.Sleep(time.Millisecond * 10) + require.Equal(t, expectedInitialResponse, mockConnection.lastMessage) + + // check response after update + notifyChan <- version + time.Sleep(time.Millisecond * 10) + require.Equal(t, expectedUpdateResponse, mockConnection.lastMessage) +} diff --git a/dot/rpc/subscription/subscription.go b/dot/rpc/subscription/subscription.go index e20ed73420..9dfe6cba70 100644 --- a/dot/rpc/subscription/subscription.go +++ b/dot/rpc/subscription/subscription.go @@ -47,6 +47,8 @@ func (c *WSConn) getUnsubListener(method string, params interface{}) (unsubListe switch method { case "state_unsubscribeStorage": unsub = c.unsubscribeStorageListener + case "state_unsubscribeRuntimeVersion": + unsub = c.unsubscribeRuntimeVersionListener case "grandpa_unsubscribeJustifications": unsub = c.unsubscribeGrandpaJustificationListener default: diff --git a/dot/rpc/subscription/websocket.go b/dot/rpc/subscription/websocket.go index 3d351b6c66..f53e51f952 100644 --- a/dot/rpc/subscription/websocket.go +++ b/dot/rpc/subscription/websocket.go @@ -32,6 +32,7 @@ import ( "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/runtime" log "github.com/ChainSafe/log15" "github.com/gorilla/websocket" ) @@ -350,17 +351,26 @@ func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (Listener } func (c *WSConn) initRuntimeVersionListener(reqID float64, _ interface{}) (Listener, error) { - rvl := &RuntimeVersionListener{ - wsconn: c, - } - if c.CoreAPI == nil { c.safeSendError(reqID, nil, "error CoreAPI not set") return nil, fmt.Errorf("error CoreAPI not set") } + rvl := &RuntimeVersionListener{ + wsconn: c, + runtimeUpdate: make(chan runtime.Version), + coreAPI: c.CoreAPI, + } + + chanID, err := c.BlockAPI.RegisterRuntimeUpdatedChannel(rvl.runtimeUpdate) + if err != nil { + return nil, err + } + c.mu.Lock() + rvl.channelID = chanID + c.qtyListeners++ rvl.subID = atomic.AddUint32(&c.qtyListeners, 1) c.Subscriptions[rvl.subID] = rvl @@ -371,6 +381,19 @@ func (c *WSConn) initRuntimeVersionListener(reqID float64, _ interface{}) (Liste return rvl, nil } +func (c *WSConn) unsubscribeRuntimeVersionListener(reqID float64, l Listener, _ interface{}) { + observer, ok := l.(VersionListener) + if !ok { + initRes := newBooleanResponseJSON(false, reqID) + c.safeSend(initRes) + return + } + id := observer.GetChannelID() + + res := c.BlockAPI.UnregisterRuntimeUpdatedChannel(id) + c.safeSend(newBooleanResponseJSON(res, reqID)) +} + func (c *WSConn) initGrandpaJustificationListener(reqID float64, _ interface{}) (Listener, error) { if c.BlockAPI == nil { c.safeSendError(reqID, nil, "error BlockAPI not set") diff --git a/dot/rpc/websocket_test.go b/dot/rpc/websocket_test.go index 25cff1dc96..8520b6ad8e 100644 --- a/dot/rpc/websocket_test.go +++ b/dot/rpc/websocket_test.go @@ -43,7 +43,7 @@ var testCalls = []struct { {[]byte(`{"jsonrpc":"2.0","method":"state_subscribeStorage","params":[],"id":4}`), []byte(`{"jsonrpc":"2.0","result":2,"id":4}` + "\n")}, {[]byte(`{"jsonrpc":"2.0","method":"chain_subscribeFinalizedHeads","params":[],"id":5}`), []byte(`{"jsonrpc":"2.0","result":3,"id":5}` + "\n")}, {[]byte(`{"jsonrpc":"2.0","method":"author_submitAndWatchExtrinsic","params":["0x010203"],"id":6}`), []byte("{\"jsonrpc\":\"2.0\",\"error\":{\"code\":null,\"message\":\"Failed to call the `TaggedTransactionQueue_validate_transaction` exported function.\"},\"id\":6}\n")}, - {[]byte(`{"jsonrpc":"2.0","method":"state_subscribeRuntimeVersion","params":[],"id":7}`), []byte("{\"jsonrpc\":\"2.0\",\"result\":5,\"id\":7}\n")}, + {[]byte(`{"jsonrpc":"2.0","method":"state_subscribeRuntimeVersion","params":[],"id":7}`), []byte("{\"jsonrpc\":\"2.0\",\"result\":6,\"id\":7}\n")}, } func TestHTTPServer_ServeHTTP(t *testing.T) { diff --git a/dot/state/block.go b/dot/state/block.go index 1d0aa41bb5..26a5114d5c 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -50,12 +50,14 @@ type BlockState struct { lastFinalised common.Hash // block notifiers - imported map[byte]chan<- *types.Block - finalised map[byte]chan<- *types.FinalisationInfo - importedLock sync.RWMutex - finalisedLock sync.RWMutex - importedBytePool *common.BytePool - finalisedBytePool *common.BytePool + imported map[byte]chan<- *types.Block + finalised map[byte]chan<- *types.FinalisationInfo + importedLock sync.RWMutex + finalisedLock sync.RWMutex + importedBytePool *common.BytePool + finalisedBytePool *common.BytePool + runtimeUpdateSubscriptionsLock sync.RWMutex + runtimeUpdateSubscriptions map[uint32]chan<- runtime.Version pruneKeyCh chan *types.Header } @@ -67,13 +69,14 @@ func NewBlockState(db chaindb.Database, bt *blocktree.BlockTree) (*BlockState, e } bs := &BlockState{ - bt: bt, - dbPath: db.Path(), - baseState: NewBaseState(db), - db: chaindb.NewTable(db, blockPrefix), - imported: make(map[byte]chan<- *types.Block), - finalised: make(map[byte]chan<- *types.FinalisationInfo), - pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), + bt: bt, + dbPath: db.Path(), + baseState: NewBaseState(db), + db: chaindb.NewTable(db, blockPrefix), + imported: make(map[byte]chan<- *types.Block), + finalised: make(map[byte]chan<- *types.FinalisationInfo), + pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), + runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), } genesisBlock, err := bs.GetBlockByNumber(big.NewInt(0)) @@ -95,12 +98,13 @@ func NewBlockState(db chaindb.Database, bt *blocktree.BlockTree) (*BlockState, e // NewBlockStateFromGenesis initialises a BlockState from a genesis header, saving it to the database located at basePath func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header) (*BlockState, error) { bs := &BlockState{ - bt: blocktree.NewBlockTreeFromRoot(header, db), - baseState: NewBaseState(db), - db: chaindb.NewTable(db, blockPrefix), - imported: make(map[byte]chan<- *types.Block), - finalised: make(map[byte]chan<- *types.FinalisationInfo), - pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), + bt: blocktree.NewBlockTreeFromRoot(header, db), + baseState: NewBaseState(db), + db: chaindb.NewTable(db, blockPrefix), + imported: make(map[byte]chan<- *types.Block), + finalised: make(map[byte]chan<- *types.FinalisationInfo), + pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), + runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), } if err := bs.setArrivalTime(header.Hash(), time.Now()); err != nil { @@ -660,6 +664,7 @@ func (bs *BlockState) HandleRuntimeChanges(newState *rtstorage.TrieState, rt run } codeSubBlockHash := bs.baseState.LoadCodeSubstitutedBlockHash() + if !codeSubBlockHash.Equal(common.Hash{}) { newVersion, err := rt.CheckRuntimeVersion(code) //nolint if err != nil { @@ -705,6 +710,11 @@ func (bs *BlockState) HandleRuntimeChanges(newState *rtstorage.TrieState, rt run return fmt.Errorf("failed to update code substituted block hash: %w", err) } + newVersion, err := rt.Version() + if err != nil { + return fmt.Errorf("failed to retrieve runtime version: %w", err) + } + go bs.notifyRuntimeUpdated(newVersion) return nil } diff --git a/dot/state/block_notify.go b/dot/state/block_notify.go index 8c8c7e7c77..fc9828d41b 100644 --- a/dot/state/block_notify.go +++ b/dot/state/block_notify.go @@ -17,8 +17,13 @@ package state import ( + "errors" + "sync" + "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/runtime" + "github.com/google/uuid" ) // RegisterImportedChannel registers a channel for block notification upon block import. @@ -132,3 +137,62 @@ func (bs *BlockState) notifyFinalized(hash common.Hash, round, setID uint64) { }(ch) } } + +func (bs *BlockState) notifyRuntimeUpdated(version runtime.Version) { + bs.runtimeUpdateSubscriptionsLock.RLock() + defer bs.runtimeUpdateSubscriptionsLock.RUnlock() + + if len(bs.runtimeUpdateSubscriptions) == 0 { + return + } + + logger.Debug("notifying runtime updated chans...", "chans", bs.runtimeUpdateSubscriptions) + var wg sync.WaitGroup + wg.Add(len(bs.runtimeUpdateSubscriptions)) + for _, ch := range bs.runtimeUpdateSubscriptions { + go func(ch chan<- runtime.Version) { + defer wg.Done() + ch <- version + }(ch) + } + wg.Wait() +} + +// RegisterRuntimeUpdatedChannel function to register chan that is notified when runtime version changes +func (bs *BlockState) RegisterRuntimeUpdatedChannel(ch chan<- runtime.Version) (uint32, error) { + bs.runtimeUpdateSubscriptionsLock.Lock() + defer bs.runtimeUpdateSubscriptionsLock.Unlock() + + if len(bs.runtimeUpdateSubscriptions) == 256 { + return 0, errors.New("channel limit reached") + } + + id := bs.generateID() + + bs.runtimeUpdateSubscriptions[id] = ch + return id, nil +} + +// UnregisterRuntimeUpdatedChannel function to unregister runtime updated channel +func (bs *BlockState) UnregisterRuntimeUpdatedChannel(id uint32) bool { + bs.runtimeUpdateSubscriptionsLock.Lock() + defer bs.runtimeUpdateSubscriptionsLock.Unlock() + ch, ok := bs.runtimeUpdateSubscriptions[id] + if ok { + close(ch) + delete(bs.runtimeUpdateSubscriptions, id) + return true + } + return false +} + +func (bs *BlockState) generateID() uint32 { + var uid uuid.UUID + for { + uid = uuid.New() + if bs.runtimeUpdateSubscriptions[uid.ID()] == nil { + break + } + } + return uid.ID() +} diff --git a/dot/state/block_notify_test.go b/dot/state/block_notify_test.go index 60268d1adb..6135fbaad4 100644 --- a/dot/state/block_notify_test.go +++ b/dot/state/block_notify_test.go @@ -23,7 +23,8 @@ import ( "time" "github.com/ChainSafe/gossamer/dot/types" - + "github.com/ChainSafe/gossamer/lib/runtime" + runtimemocks "github.com/ChainSafe/gossamer/lib/runtime/mocks" "github.com/stretchr/testify/require" ) @@ -153,3 +154,49 @@ func TestFinalizedChannel_Multi(t *testing.T) { bs.UnregisterFinalisedChannel(id) } } + +func TestService_RegisterUnRegisterRuntimeUpdatedChannel(t *testing.T) { + bs := newTestBlockState(t, testGenesisHeader) + ch := make(chan<- runtime.Version) + chID, err := bs.RegisterRuntimeUpdatedChannel(ch) + require.NoError(t, err) + require.NotNil(t, chID) + + res := bs.UnregisterRuntimeUpdatedChannel(chID) + require.True(t, res) +} + +func TestService_RegisterUnRegisterConcurrentCalls(t *testing.T) { + bs := newTestBlockState(t, testGenesisHeader) + + go func() { + for i := 0; i < 100; i++ { + testVer := NewMockVersion(uint32(i)) + go bs.notifyRuntimeUpdated(testVer) + } + }() + + for i := 0; i < 100; i++ { + go func() { + + ch := make(chan<- runtime.Version) + chID, err := bs.RegisterRuntimeUpdatedChannel(ch) + require.NoError(t, err) + unReg := bs.UnregisterRuntimeUpdatedChannel(chID) + require.True(t, unReg) + }() + } +} + +// NewMockVersion creates and returns an runtime Version interface mock +func NewMockVersion(specVer uint32) *runtimemocks.MockVersion { + m := new(runtimemocks.MockVersion) + m.On("SpecName").Return([]byte(`mock-spec`)) + m.On("ImplName").Return(nil) + m.On("AuthoringVersion").Return(uint32(0)) + m.On("SpecVersion").Return(specVer) + m.On("ImplVersion").Return(uint32(0)) + m.On("TransactionVersion").Return(uint32(0)) + m.On("APIItems").Return(nil) + return m +} diff --git a/go.mod b/go.mod index 8c3ed52d72..8af8618492 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/go-playground/validator/v10 v10.4.1 github.com/golang/protobuf v1.4.3 github.com/google/go-cmp v0.5.6 + github.com/google/uuid v1.1.5 github.com/gorilla/mux v1.8.0 github.com/gorilla/rpc v1.2.0 github.com/gorilla/websocket v1.4.2 diff --git a/lib/runtime/mocks/version.go b/lib/runtime/mocks/version.go new file mode 100644 index 0000000000..9d2703dedc --- /dev/null +++ b/lib/runtime/mocks/version.go @@ -0,0 +1,140 @@ +// Code generated by mockery v2.8.0. DO NOT EDIT. + +package mocks + +import ( + runtime "github.com/ChainSafe/gossamer/lib/runtime" + mock "github.com/stretchr/testify/mock" +) + +// MockVersion is an autogenerated mock type for the Version type +type MockVersion struct { + mock.Mock +} + +// APIItems provides a mock function with given fields: +func (_m *MockVersion) APIItems() []*runtime.APIItem { + ret := _m.Called() + + var r0 []*runtime.APIItem + if rf, ok := ret.Get(0).(func() []*runtime.APIItem); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*runtime.APIItem) + } + } + + return r0 +} + +// AuthoringVersion provides a mock function with given fields: +func (_m *MockVersion) AuthoringVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// Encode provides a mock function with given fields: +func (_m *MockVersion) Encode() ([]byte, error) { + ret := _m.Called() + + var r0 []byte + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ImplName provides a mock function with given fields: +func (_m *MockVersion) ImplName() []byte { + ret := _m.Called() + + var r0 []byte + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + return r0 +} + +// ImplVersion provides a mock function with given fields: +func (_m *MockVersion) ImplVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// SpecName provides a mock function with given fields: +func (_m *MockVersion) SpecName() []byte { + ret := _m.Called() + + var r0 []byte + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + return r0 +} + +// SpecVersion provides a mock function with given fields: +func (_m *MockVersion) SpecVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// TransactionVersion provides a mock function with given fields: +func (_m *MockVersion) TransactionVersion() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +}