Skip to content

Commit

Permalink
Merge #3794
Browse files Browse the repository at this point in the history
3794: Galois code review r=coot a=coot

- improve/add documentation
- GR-FIXME: comments and fixme's from Galois Review
- refactors


Co-authored-by: Mark Tullsen <[email protected]>
Co-authored-by: Marcin Szamotulski <[email protected]>
  • Loading branch information
3 people authored Jun 28, 2022
2 parents a95be8e + caba7ad commit 233ae41
Show file tree
Hide file tree
Showing 35 changed files with 655 additions and 678 deletions.
44 changes: 30 additions & 14 deletions network-mux/src/Control/Concurrent/JobPool.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
{-# LANGUAGE ScopedTypeVariables #-}


-- | This module allows the management of a multiple Async jobs which
-- are grouped by an 'Ord group => group' type.
--
module Control.Concurrent.JobPool
( JobPool
, Job (..)
, withJobPool
, forkJob
, readSize
, readGroupSize
, collect
, waitForJob
, cancelGroup
) where

Expand All @@ -19,19 +22,27 @@ import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

import Control.Exception (SomeAsyncException (..))
import Control.Monad (when)
import Control.Monad (void, when)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork (MonadThread (..))
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow


-- | JobPool allows to submit asynchronous jobs, wait for their completion or
-- cancel. Jobs are grouped, each group can be cancelled separately.
--
data JobPool group m a = JobPool {
jobsVar :: !(TVar m (Map (group, ThreadId m) (Async m ()))),
completionQueue :: !(TQueue m a)
}

data Job group m a = Job (m a) (SomeException -> m a) group String
-- | An asynchronous job which belongs to some group and its exception handler.
--
data Job group m a =
Job (m a) -- ^ job
(SomeException -> m a) -- ^ error handler
group -- ^ job group
String -- ^ thread label

withJobPool :: forall group m a b.
(MonadAsync m, MonadThrow m, MonadLabelledSTM m)
Expand All @@ -54,7 +65,7 @@ withJobPool =
-- condition).
close :: JobPool group m a -> m ()
close JobPool{jobsVar} = do
jobs <- atomically (readTVar jobsVar)
jobs <- readTVarIO jobsVar
mapM_ uninterruptibleCancel jobs

forkJob :: forall group m a.
Expand Down Expand Up @@ -97,20 +108,25 @@ readGroupSize JobPool{jobsVar} group =
. Map.filterWithKey (\(group', _) _ -> group' == group)
<$> readTVar jobsVar

collect :: MonadSTM m => JobPool group m a -> STM m a
collect JobPool{completionQueue} = readTQueue completionQueue
-- | Wait for next successfully completed job. Unlike 'wait' it will not throw
-- if a job errors.
--
waitForJob :: MonadSTM m => JobPool group m a -> STM m a
waitForJob JobPool{completionQueue} = readTQueue completionQueue

-- | Cancel all threads in a given group. Blocks until all threads terminated.
--
cancelGroup :: ( MonadAsync m
, Eq group
)
=> JobPool group m a -> group -> m ()
cancelGroup JobPool { jobsVar } group = do
jobs <- atomically (readTVar jobsVar)
_ <- Map.traverseWithKey (\(group', _) thread ->
when (group' == group) $
cancel thread
)
jobs
return ()
jobs <- readTVarIO jobsVar
void $ Map.traverseWithKey
(\(group', _) thread ->
when (group' == group) $
cancel thread
)
jobs


2 changes: 1 addition & 1 deletion network-mux/src/Network/Mux.hs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ monitor tracer timeout jobpool egressQueue cmdQueue muxStatus =
go !monitorCtx@MonitorCtx { mcOnDemandProtocols } = do
result <- atomically $ runFirstToFinish $
-- wait for a mini-protocol thread to terminate
(FirstToFinish $ EventJobResult <$> JobPool.collect jobpool)
(FirstToFinish $ EventJobResult <$> JobPool.waitForJob jobpool)

-- wait for a new control command
<> (FirstToFinish $ EventControlCmd <$> readTQueue cmdQueue)
Expand Down
5 changes: 2 additions & 3 deletions network-mux/src/Network/Mux/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,14 @@ newtype MiniProtocolNum = MiniProtocolNum Word16
deriving (Eq, Ord, Enum, Ix, Show)

-- | Per Miniprotocol limits
data MiniProtocolLimits =
newtype MiniProtocolLimits =
MiniProtocolLimits {
-- | Limit on the maximum number of bytes that can be queued in the
-- miniprotocol's ingress queue.
--
maximumIngressQueue :: !Int
maximumIngressQueue :: Int
}


-- $interface
--
-- To run a node you will also need a bearer and a way to run a server, see
Expand Down
19 changes: 10 additions & 9 deletions ouroboros-network-framework/demo/connection-manager.hs
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,14 @@ withBidirectionalConnectionManager snocket socket
serverApplication :: LazySTM.TVar m [[Int]]
-> LazySTM.TVar m [[Int]]
-> LazySTM.TVar m [[Int]]
-> Bundle
-> TemperatureBundle
(ConnectionId peerAddr
-> ControlMessageSTM m
-> [MiniProtocol InitiatorResponderMode ByteString m () ()])
serverApplication hotRequestsVar
warmRequestsVar
establishedRequestsVar
= Bundle {
= TemperatureBundle {
withHot = WithHot $ \_ _ ->
[ let miniProtocolNum = Mux.MiniProtocolNum 1
in MiniProtocol {
Expand Down Expand Up @@ -371,12 +371,12 @@ runInitiatorProtocols
=> SingMuxMode muxMode
-> Mux.Mux muxMode m
-> MuxBundle muxMode ByteString m a b
-> m (Maybe (Bundle [a]))
-> m (Maybe (TemperatureBundle [a]))
runInitiatorProtocols
singMuxMode mux
(Bundle (WithHot hotPtcls)
(WithWarm warmPtcls)
(WithEstablished establishedPtcls)) = do
(TemperatureBundle (WithHot hotPtcls)
(WithWarm warmPtcls)
(WithEstablished establishedPtcls)) = do
-- start all protocols
hotSTMs <- traverse runInitiator hotPtcls
warmSTMs <- traverse runInitiator warmPtcls
Expand All @@ -395,9 +395,10 @@ runInitiatorProtocols
([], established)) ->
return
. Just
$ Bundle (WithHot hot)
(WithWarm warm)
(WithEstablished established)
$ TemperatureBundle
(WithHot hot)
(WithWarm warm)
(WithEstablished established)
where
runInitiator :: MiniProtocol muxMode ByteString m a b
-> m (STM m (Either SomeException a))
Expand Down
6 changes: 3 additions & 3 deletions ouroboros-network-framework/src/Ouroboros/Network/Channel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ hoistChannel nat channel = Channel
--
fixedInputChannel :: MonadSTM m => [a] -> m (Channel m a)
fixedInputChannel xs0 = do
v <- atomically $ newTVar xs0
v <- newTVarIO xs0
return Channel {send, recv = recv v}
where
recv v = atomically $ do
Expand Down Expand Up @@ -155,8 +155,8 @@ createConnectedChannels :: MonadSTM m => m (Channel m a, Channel m a)
createConnectedChannels = do
-- Create two TMVars to act as the channel buffer (one for each direction)
-- and use them to make both ends of a bidirectional channel
bufferA <- atomically $ newEmptyTMVar
bufferB <- atomically $ newEmptyTMVar
bufferA <- newEmptyTMVarIO
bufferB <- newEmptyTMVarIO

return (mvarsAsChannel bufferB bufferA,
mvarsAsChannel bufferA bufferB)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}

-- | Implementation of 'ConnectionHandler'
--
Expand Down Expand Up @@ -101,7 +98,7 @@ data Handle (muxMode :: MuxMode) peerAddr bytes m a b =
Handle {
hMux :: !(Mux muxMode m),
hMuxBundle :: !(MuxBundle muxMode bytes m a b),
hControlMessage :: !(Bundle (StrictTVar m ControlMessage))
hControlMessage :: !(TemperatureBundle (StrictTVar m ControlMessage))
}


Expand Down Expand Up @@ -281,7 +278,7 @@ makeConnectionHandler muxTracer singMuxMode
unmask $ do
traceWith tracer (TrHandshakeSuccess versionNumber agreedOptions)
controlMessageBundle
<- (\a b c -> Bundle (WithHot a) (WithWarm b) (WithEstablished c))
<- (\a b c -> TemperatureBundle (WithHot a) (WithWarm b) (WithEstablished c))
<$> newTVarIO Continue
<*> newTVarIO Continue
<*> newTVarIO Continue
Expand Down Expand Up @@ -348,7 +345,7 @@ makeConnectionHandler muxTracer singMuxMode
unmask $ do
traceWith tracer (TrHandshakeSuccess versionNumber agreedOptions)
controlMessageBundle
<- (\a b c -> Bundle (WithHot a) (WithWarm b) (WithEstablished c))
<- (\a b c -> TemperatureBundle (WithHot a) (WithWarm b) (WithEstablished c))
<$> newTVarIO Continue
<*> newTVarIO Continue
<*> newTVarIO Continue
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
Expand Down Expand Up @@ -57,7 +57,7 @@ import Ouroboros.Network.ConnectionId
import Ouroboros.Network.ConnectionManager.Types
import qualified Ouroboros.Network.ConnectionManager.Types as CM
import Ouroboros.Network.InboundGovernor.ControlChannel
(ControlChannel)
(ControlChannel (..))
import qualified Ouroboros.Network.InboundGovernor.ControlChannel as ControlChannel
import Ouroboros.Network.MuxMode
import Ouroboros.Network.Server.RateLimiting
Expand All @@ -75,7 +75,8 @@ data ConnectionManagerArguments handlerTrace socket peerAddr handle handleError

-- | Trace state transitions.
--
cmTrTracer :: Tracer m (TransitionTrace peerAddr (ConnectionState peerAddr handle handleError version m)),
cmTrTracer :: Tracer m (TransitionTrace peerAddr
(ConnectionState peerAddr handle handleError version m)),

-- | Mux trace.
--
Expand Down Expand Up @@ -157,7 +158,7 @@ instance Eq (MutableConnState peerAddr handle handleError version m) where
newtype FreshIdSupply m = FreshIdSupply { getFreshId :: STM m Int }


-- | Create a 'FreshIdSupply' inside and 'STM' monad.
-- | Create a 'FreshIdSupply' inside an 'STM' monad.
--
newFreshIdSupply :: forall m. MonadSTM m
=> Proxy m -> STM m (FreshIdSupply m)
Expand All @@ -184,7 +185,7 @@ newMutableConnState freshIdSupply connState = do


-- | 'ConnectionManager' state: for each peer we keep a 'ConnectionState' in
-- a mutable variable, which reduce congestion on the 'TMVar' which keeps
-- a mutable variable, which reduces congestion on the 'TMVar' which keeps
-- 'ConnectionManagerState'.
--
-- It is important we can lookup by remote @peerAddr@; this way we can find if
Expand Down Expand Up @@ -330,13 +331,13 @@ instance ( Show peerAddr
, show df
]
show (InboundIdleState connId connThread _handle df) =
concat ([ "InboundIdleState "
, show connId
, " "
, show (asyncThreadId connThread)
, " "
, show df
])
concat [ "InboundIdleState "
, show connId
, " "
, show (asyncThreadId connThread)
, " "
, show df
]
show (InboundState connId connThread _handle df) =
concat [ "InboundState "
, show connId
Expand All @@ -357,10 +358,10 @@ instance ( Show peerAddr
, " "
, show (asyncThreadId connThread)
]
++ maybeToList (((' ' :) . show) <$> handleError))
++ maybeToList ((' ' :) . show <$> handleError))
show (TerminatedState handleError) =
concat (["TerminatedState"]
++ maybeToList (((' ' :) . show) <$> handleError))
++ maybeToList ((' ' :) . show <$> handleError))


getConnThread :: ConnectionState peerAddr handle handleError version m
Expand Down Expand Up @@ -411,7 +412,7 @@ isInboundConn TerminatedState {} = False


abstractState :: MaybeUnknown (ConnectionState muxMode peerAddr m a b) -> AbstractState
abstractState = \s -> case s of
abstractState = \case
Unknown -> UnknownConnectionSt
Race s' -> go s'
Known s' -> go s'
Expand Down Expand Up @@ -541,7 +542,7 @@ withConnectionManager
-- ^ Callback which runs in a thread dedicated for a given connection.
-> (handleError -> HandleErrorType)
-- ^ classify 'handleError's
-> InResponderMode muxMode (ControlChannel m (ControlChannel.NewConnection peerAddr handle))
-> InResponderMode muxMode (ControlChannel peerAddr handle m)
-- ^ On outbound duplex connections we need to notify the server about
-- a new connection.
-> (ConnectionManager muxMode socket peerAddr handle handleError m -> m a)
Expand Down Expand Up @@ -871,7 +872,7 @@ withConnectionManager ConnectionManagerArguments {
unmask (threadDelay delay)
`catch` \e ->
case fromException e
of Just (AsyncCancelled) -> do
of Just AsyncCancelled -> do
t' <- getMonotonicTime
forceThreadDelay (delay - t' `diffTime` t)
_ -> throwIO e
Expand Down Expand Up @@ -1242,8 +1243,9 @@ withConnectionManager ConnectionManagerArguments {
Just {} -> do
case inboundGovernorControlChannel of
InResponderMode controlChannel ->
atomically $ ControlChannel.newInboundConnection
controlChannel connId dataFlow handle
atomically $ ControlChannel.writeMessage
controlChannel
(ControlChannel.NewConnection Inbound connId dataFlow handle)
NotInResponderMode -> return ()
return $ Connected connId dataFlow handle

Expand Down Expand Up @@ -1809,8 +1811,9 @@ withConnectionManager ConnectionManagerArguments {
writeTVar connVar connState'
case inboundGovernorControlChannel of
InResponderMode controlChannel ->
ControlChannel.newOutboundConnection
controlChannel connId dataFlow handle
ControlChannel.writeMessage
controlChannel
(ControlChannel.NewConnection Outbound connId dataFlow handle)
NotInResponderMode -> return ()
return (Just $ mkTransition connState connState')
TerminatedState _ ->
Expand Down Expand Up @@ -1947,15 +1950,15 @@ withConnectionManager ConnectionManagerArguments {
-- operation which returns only once the connection is
-- negotiated.
ReservedOutboundState ->
return $
return
( DemoteToColdLocalError
(TrForbiddenOperation peerAddr st)
st
, Nothing
)

UnnegotiatedState _ _ _ ->
return $
return
( DemoteToColdLocalError
(TrForbiddenOperation peerAddr st)
st
Expand Down
Loading

0 comments on commit 233ae41

Please sign in to comment.