diff --git a/hydra-node/exe/hydra-node/Main.hs b/hydra-node/exe/hydra-node/Main.hs index d9415302d29..ccf737df95e 100644 --- a/hydra-node/exe/hydra-node/Main.hs +++ b/hydra-node/exe/hydra-node/Main.hs @@ -10,7 +10,6 @@ import Hydra.Cardano.Api (serialiseToRawBytesHex) import Hydra.Chain (HeadParameters (..)) import Hydra.Chain.Direct (initialChainState, loadChainContext, mkTinyWallet, withDirectChain) import Hydra.Chain.Direct.ScriptRegistry (publishHydraScripts) -import Hydra.Chain.Direct.State (ChainStateAt (..)) import Hydra.Chain.Direct.Util (readKeyPair) import Hydra.HeadLogic ( ClosedState (..), @@ -102,8 +101,8 @@ main = do nodeState <- createNodeState hs ctx <- loadChainContext chainConfig party otherParties hydraScriptsTxId wallet <- mkTinyWallet (contramap DirectChain tracer) chainConfig - let ChainStateAt{recordedAt} = getChainState hs - withDirectChain (contramap DirectChain tracer) chainConfig ctx recordedAt wallet (chainCallback nodeState eq) $ \chain -> do + let chainStateAt = getChainState hs + withDirectChain (contramap DirectChain tracer) chainConfig ctx wallet chainStateAt (chainCallback eq) $ \chain -> do let RunOptions{host, port, peers, nodeId} = opts withNetwork (contramap Network tracer) host port peers nodeId (putEvent eq . NetworkEvent defaultTTL) $ \hn -> do let RunOptions{apiHost, apiPort} = opts diff --git a/hydra-node/src/Hydra/Chain.hs b/hydra-node/src/Hydra/Chain.hs index d5b76320295..0d6c982a4b4 100644 --- a/hydra-node/src/Hydra/Chain.hs +++ b/hydra-node/src/Hydra/Chain.hs @@ -225,7 +225,7 @@ instance -- | A callback indicating receival of a potential Hydra transaction which is Maybe -- observing a relevant 'ChainEvent tx'. -type ChainCallback tx m = (ChainStateType tx -> Maybe (ChainEvent tx)) -> m () +type ChainCallback tx m = ChainEvent tx -> m () -- | A type tying both posting and observing transactions into a single /Component/. type ChainComponent tx m a = ChainCallback tx m -> (Chain tx m -> m a) -> m a diff --git a/hydra-node/src/Hydra/Chain/Direct.hs b/hydra-node/src/Hydra/Chain/Direct.hs index 52373fa2caf..7864dd4ca74 100644 --- a/hydra-node/src/Hydra/Chain/Direct.hs +++ b/hydra-node/src/Hydra/Chain/Direct.hs @@ -29,6 +29,7 @@ import Control.Exception (IOException) import Control.Monad.Class.MonadSTM ( newEmptyTMVar, newTQueueIO, + newTVarIO, putTMVar, readTQueue, takeTMVar, @@ -38,7 +39,6 @@ import Control.Monad.Trans.Except (runExcept) import Control.Tracer (nullTracer) import Hydra.Cardano.Api ( CardanoMode, - ChainPoint, ConsensusMode (CardanoMode), EraHistory (EraHistory), LedgerEra, @@ -78,6 +78,7 @@ import Hydra.Chain.Direct.State ( ChainState (Idle), ChainStateAt (..), ) +import qualified Hydra.Chain.Direct.State as ChainState import Hydra.Chain.Direct.TimeHandle (queryTimeHandle) import Hydra.Chain.Direct.Util ( Block, @@ -210,11 +211,13 @@ withDirectChain :: Tracer IO DirectChainLog -> ChainConfig -> ChainContext -> - -- | Last known point on chain as loaded from persistence. - Maybe ChainPoint -> TinyWallet IO -> + ChainStateAt -> ChainComponent Tx IO a -withDirectChain tracer config ctx persistedPoint wallet callback action = do +withDirectChain tracer config ctx wallet chainStateAt callback action = do + -- Last known point on chain as loaded from persistence. + let persistedPoint = recordedAt chainStateAt + chainStateTVar <- newTVarIO $ ChainState.chainState chainStateAt queue <- newTQueueIO -- Select a chain point from which to start synchronizing chainPoint <- maybe (queryTip networkId nodeSocket) pure $ do @@ -232,7 +235,7 @@ withDirectChain tracer config ctx persistedPoint wallet callback action = do res <- race ( handle onIOException $ do - let handler = chainSyncHandler tracer callback getTimeHandle ctx + let handler = chainSyncHandler tracer callback getTimeHandle ctx chainStateTVar let intersection = toConsensusPointInMode CardanoMode chainPoint let client = ouroborosApplication tracer intersection queue handler wallet withIOManager $ \iocp -> diff --git a/hydra-node/src/Hydra/Chain/Direct/Handlers.hs b/hydra-node/src/Hydra/Chain/Direct/Handlers.hs index c1bd44cb939..b2176738c60 100644 --- a/hydra-node/src/Hydra/Chain/Direct/Handlers.hs +++ b/hydra-node/src/Hydra/Chain/Direct/Handlers.hs @@ -16,7 +16,7 @@ import Cardano.Ledger.Crypto (StandardCrypto) import Cardano.Ledger.Era (SupportsSegWit (fromTxSeq)) import qualified Cardano.Ledger.Shelley.API as Ledger import Cardano.Slotting.Slot (SlotNo (..)) -import Control.Monad.Class.MonadSTM (throwSTM) +import Control.Monad.Class.MonadSTM (throwSTM, writeTVar) import Data.Sequence.Strict (StrictSeq) import Hydra.Cardano.Api ( ChainPoint (..), @@ -190,18 +190,22 @@ chainSyncHandler :: GetTimeHandle m -> -- | Contextual information about our chain connection. ChainContext -> + -- TVar containing the local ChainState + -- TODO try ChainStateAt + TVar m ChainState -> -- | A chain-sync handler to use in a local-chain-sync client. ChainSyncHandler m -chainSyncHandler tracer callback getTimeHandle ctx = +chainSyncHandler tracer callback getTimeHandle ctx chainStateTVar = ChainSyncHandler { onRollBackward , onRollForward } where + onRollBackward :: Point Block -> m () onRollBackward rollbackPoint = do let point = fromConsensusPointInMode CardanoMode rollbackPoint traceWith tracer $ RolledBackward{point} - callback (const . Just $ Rollback $ chainSlotFromPoint point) + callback (Rollback $ chainSlotFromPoint point) onRollForward :: Block -> m () onRollForward blk = do @@ -221,22 +225,21 @@ chainSyncHandler tracer callback getTimeHandle ctx = Left reason -> throwIO TimeConversionException{slotNo, reason} Right utcTime -> - callback (const . Just $ Tick utcTime) + callback (Tick utcTime) - forM_ receivedTxs $ \tx -> - callback $ \ChainStateAt{chainState = cs} -> - case observeSomeTx ctx cs tx of - Nothing -> Nothing - Just (observedTx, cs') -> - Just $ - Observation - { observedTx - , newChainState = - ChainStateAt - { chainState = cs' - , recordedAt = Just point - } - } + forM_ receivedTxs $ \tx -> do + -- TODO: refactor to use modifyTVar instead of read and then write + cs <- atomically $ readTVar chainStateTVar + case observeSomeTx ctx cs tx of + Nothing -> pure () + Just (observedTx, cs') -> do + let newChainState = + ChainStateAt + { chainState = cs' + , recordedAt = Just point + } + atomically $ writeTVar chainStateTVar cs' + callback Observation{observedTx, newChainState} prepareTxToPost :: (MonadSTM m, MonadThrow (STM m)) => diff --git a/hydra-node/src/Hydra/Node.hs b/hydra-node/src/Hydra/Node.hs index d84388851a8..9ce49686799 100644 --- a/hydra-node/src/Hydra/Node.hs +++ b/hydra-node/src/Hydra/Node.hs @@ -38,7 +38,7 @@ import Control.Monad.Class.MonadSTM ( ) import Hydra.API.Server (Server, sendOutput) import Hydra.Cardano.Api (AsType (AsSigningKey, AsVerificationKey)) -import Hydra.Chain (Chain (..), ChainCallback, ChainEvent (..), ChainStateType, IsChainState, PostTxError) +import Hydra.Chain (Chain (..), ChainCallback, ChainStateType, IsChainState, PostTxError) import Hydra.Chain.Direct.Util (readFileTextEnvelopeThrow) import Hydra.Crypto (AsType (AsHydraKey)) import Hydra.HeadLogic ( @@ -48,8 +48,6 @@ import Hydra.HeadLogic ( HeadState (..), Outcome (..), defaultTTL, - getChainState, - setChainState, ) import qualified Hydra.HeadLogic as Logic import Hydra.Ledger (IsTx, Ledger) @@ -262,23 +260,7 @@ createNodeState initialState = do } chainCallback :: - MonadSTM m => - NodeState tx m -> EventQueue m (Event tx) -> ChainCallback tx m -chainCallback NodeState{modifyHeadState} eq cont = do - -- Provide chain state to continuation and update it when we get a newState - -- NOTE: Although we do handle the chain state explictly in the 'HeadLogic', - -- this is required as multiple transactions may be observed and the chain - -- state shall accumulate the state changes coming with those observations. - mEvent <- atomically . modifyHeadState $ \hs -> - case cont $ getChainState hs of - Nothing -> - (Nothing, hs) - Just ev@Observation{newChainState} -> - (Just ev, setChainState newChainState hs) - Just ev -> - (Just ev, hs) - case mEvent of - Nothing -> pure () - Just chainEvent -> putEvent eq $ OnChainEvent{chainEvent} \ No newline at end of file +chainCallback eventQueue chainEvent = do + putEvent eventQueue $ OnChainEvent{chainEvent} \ No newline at end of file diff --git a/hydra-node/test/Hydra/Chain/Direct/HandlersSpec.hs b/hydra-node/test/Hydra/Chain/Direct/HandlersSpec.hs index cfb931a80b4..316135bccd5 100644 --- a/hydra-node/test/Hydra/Chain/Direct/HandlersSpec.hs +++ b/hydra-node/test/Hydra/Chain/Direct/HandlersSpec.hs @@ -16,7 +16,6 @@ import Hydra.Cardano.Api ( toLedgerTx, ) import Hydra.Chain ( - ChainCallback, ChainEvent (..), ChainSlot (..), HeadParameters, @@ -110,10 +109,11 @@ spec = do monadicIO $ do (timeHandle, slot) <- pickBlind genTimeHandleWithSlotPastHorizon blk <- pickBlind $ genBlockAt slot [] - chainContext <- pickBlind arbitrary + chainState <- pickBlind arbitrary + chainStateVar <- run $ newTVarIO chainState let chainSyncCallback = \_cont -> failure "Unexpected callback" - handler = chainSyncHandler nullTracer chainSyncCallback (pure timeHandle) chainContext + handler = chainSyncHandler nullTracer chainSyncCallback (pure timeHandle) chainContext chainStateVar run $ onRollForward handler blk @@ -122,48 +122,34 @@ spec = do prop "yields observed transactions rolling forward" . monadicIO $ do -- Generate a state and related transaction and a block containing it (ctx, st, tx, transition) <- pick genChainStateWithTx - let chainState = ChainStateAt{chainState = st, recordedAt = Nothing} blk <- pickBlind $ genBlockAt 1 [tx] monitor (label $ show transition) - + chainStateVar <- run $ newTVarIO st timeHandle <- pickBlind arbitrary - let callback cont = - -- Give chain state in which we expect the 'tx' to yield an 'Observation'. - case cont chainState of - Nothing -> - -- XXX: We need this to debug as 'failure' (via 'run') does not - -- yield counter examples. - failure . toString $ - unlines - [ "expected continuation to yield an event" - , "transition: " <> show transition - , "chainState: " <> show st - ] - Just Rollback{} -> - failure "rolled back but expected roll forward." - Just Tick{} -> pure () - Just Observation{observedTx} -> - fst <$> observeSomeTx ctx st tx `shouldBe` Just observedTx + let callback = \case + Rollback{} -> + failure "rolled back but expected roll forward." + Tick{} -> pure () + Observation{observedTx} -> + if (fst <$> observeSomeTx ctx st tx) /= Just observedTx + then failure $ show (fst <$> observeSomeTx ctx st tx) <> " /= " <> show (Just observedTx) + else pure () - let handler = chainSyncHandler nullTracer callback (pure timeHandle) ctx + let handler = chainSyncHandler nullTracer callback (pure timeHandle) ctx chainStateVar run $ onRollForward handler blk prop "yields rollback events onRollBackward" . monadicIO $ do - (chainContext, chainState, blocks) <- pickBlind genSequenceOfObservableBlocks + (chainContext, chainStateAt, blocks) <- pickBlind genSequenceOfObservableBlocks (rollbackSlot, rollbackPoint) <- pick $ genRollbackPoint blocks monitor $ label ("Rollback to: " <> show rollbackSlot <> " / " <> show (length blocks)) timeHandle <- pickBlind arbitrary -- Mock callback which keeps the chain state in a tvar - stateVar <- run $ newTVarIO chainState + chainStateVar <- run $ newTVarIO (chainState chainStateAt) rolledBackTo <- run newEmptyTMVarIO - let callback cont = do - cs <- readTVarIO stateVar - case cont cs of - Nothing -> do - failure "expected continuation to yield observation" - Just Tick{} -> pure () - Just (Rollback slot) -> atomically $ putTMVar rolledBackTo slot - Just Observation{newChainState} -> atomically $ writeTVar stateVar newChainState + let callback = \case + Tick{} -> pure () + (Rollback slot) -> atomically $ putTMVar rolledBackTo slot + Observation{newChainState} -> atomically $ writeTVar chainStateVar (chainState newChainState) let handler = chainSyncHandler @@ -171,6 +157,7 @@ spec = do callback (pure timeHandle) chainContext + chainStateVar -- Simulate some chain following run $ mapM_ (onRollForward handler) blocks -- Inject the rollback to somewhere between any of the previous state @@ -188,16 +175,14 @@ spec = do recordEventsHandler :: ChainContext -> ChainStateAt -> GetTimeHandle IO -> IO (ChainSyncHandler IO, IO [ChainEvent Tx]) recordEventsHandler ctx _cs getTimeHandle = do eventsVar <- newTVarIO [] - let handler = chainSyncHandler nullTracer (recordEvents eventsVar) getTimeHandle ctx + chainStateVar <- newTVarIO (chainState _cs) + let handler = chainSyncHandler nullTracer (recordEvents eventsVar) getTimeHandle ctx chainStateVar pure (handler, getEvents eventsVar) where getEvents = readTVarIO - recordEvents :: TVar IO [ChainEvent Tx] -> ChainCallback Tx IO - recordEvents var cont = do - case cont _cs of - Nothing -> pure () - Just e -> atomically $ modifyTVar var (e :) + recordEvents var event = do + atomically $ modifyTVar var (event :) withCounterExample :: [Block] -> TVar IO ChainStateAt -> IO a -> PropertyM IO a withCounterExample blks headState step = do diff --git a/hydra-node/test/Hydra/Model/MockChain.hs b/hydra-node/test/Hydra/Model/MockChain.hs index bb50a510207..e64cb4b7ee8 100644 --- a/hydra-node/test/Hydra/Model/MockChain.hs +++ b/hydra-node/test/Hydra/Model/MockChain.hs @@ -28,10 +28,12 @@ import Hydra.BehaviorSpec ( ConnectToChain (..), ) import Hydra.Chain (Chain (..)) +import Hydra.Chain.Direct (initialChainState) import Hydra.Chain.Direct.Fixture (testNetworkId) import Hydra.Chain.Direct.Handlers (ChainSyncHandler, DirectChainLog, SubmitTx, chainSyncHandler, mkChain, onRollBackward, onRollForward) import Hydra.Chain.Direct.ScriptRegistry (ScriptRegistry (..)) import Hydra.Chain.Direct.State (ChainContext (..), ChainStateAt (..)) +import qualified Hydra.Chain.Direct.State as ChainState import qualified Hydra.Chain.Direct.State as S import Hydra.Chain.Direct.TimeHandle (TimeHandle) import qualified Hydra.Chain.Direct.Util as Util @@ -80,6 +82,7 @@ mockChainAndNetwork :: ContestationPeriod -> m (ConnectToChain Tx m, Async m ()) mockChainAndNetwork tr seedKeys nodes cp = do + chainStateTVar <- newTVarIO $ ChainState.chainState initialChainState queue <- newTQueueIO labelTQueueIO queue "chain-queue" chain <- newTVarIO (0, 0, Empty) @@ -119,8 +122,8 @@ mockChainAndNetwork tr seedKeys nodes cp = do let seedInput = genTxIn `generateWith` 42 nodeState <- createNodeState $ Idle IdleState{chainState} let HydraNode{eq} = node - let callback = chainCallback nodeState eq - let chainHandler = chainSyncHandler tr callback getTimeHandle ctx + let callback = chainCallback eq + let chainHandler = chainSyncHandler tr callback getTimeHandle ctx chainStateTVar let node' = node { hn =