Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assertions: part 2 #1629

Merged
merged 6 commits into from
Jul 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 67 additions & 34 deletions semantics/executable-spec/src/Control/State/Transition/Extended.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
Expand All @@ -13,6 +14,7 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
Expand Down Expand Up @@ -61,7 +63,7 @@ import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (modify, runStateT)
import qualified Control.Monad.Trans.State.Strict as MonadState
import Data.Data (Data, Typeable)
import Data.Foldable (find, for_, traverse_)
import Data.Foldable (find, traverse_)
import Data.Functor ((<&>))
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
Expand Down Expand Up @@ -421,44 +423,45 @@ applySTSInternal ap goRule ctx =
applySTSInternal' SInitial env =
goRule env `traverse` initialRules
applySTSInternal' STransition jc = do
when (assertPre ap) $
for_ (assertions @s) $
( \case
PreCondition msg cond ->
if cond jc
then
throw $
AssertionViolation
{ avSTS = show $ typeRep (Proxy @s),
avMsg = msg,
avCtx = jc,
avState = Nothing
}
else pure ()
_ -> pure ()
)
!_ <-
when (assertPre ap)
$! sfor_ (assertions @s)
$! ( \case
PreCondition msg cond ->
if not (cond jc)
then
throw
$! AssertionViolation
{ avSTS = show $ typeRep (Proxy @s),
avMsg = msg,
avCtx = jc,
avState = Nothing
}
else pure ()
_ -> pure ()
)
res <- goRule jc `traverse` transitionRules
-- We only care about running postconditions if the state transition was
-- successful.
case (assertPost ap, successOrFirstFailure res) of
!_ <- case (assertPost ap, successOrFirstFailure res) of
(True, (st, [])) ->
for_ (assertions @s) $
( \case
PostCondition msg cond ->
if cond jc st
then
throw $
AssertionViolation
{ avSTS = show $ typeRep (Proxy @s),
avMsg = msg,
avCtx = jc,
avState = Just st
}
else pure ()
_ -> pure ()
)
sfor_ (assertions @s)
$! ( \case
PostCondition msg cond ->
if not (cond jc st)
then
throw
$! AssertionViolation
{ avSTS = show $ typeRep (Proxy @s),
avMsg = msg,
avCtx = jc,
avState = Just st
}
else pure ()
_ -> pure ()
)
_ -> pure ()
pure res
pure $! res

assertPre :: AssertionPolicy -> Bool
assertPre AssertionsAll = True
Expand All @@ -476,3 +479,33 @@ applySTSInternal ap goRule ctx =
-- TODO move this somewhere more sensible
newtype Threshold a = Threshold a
deriving (Eq, Ord, Show, Data, Typeable, NoUnexpectedThunks)

{------------------------------------------------------------------------------
-- Utils
------------------------------------------------------------------------------}

-- | Map each element of a structure to an action, evaluate these actions from
-- left to right, and ignore the results. For a version that doesn't ignore the
-- results see 'Data.Traversable.traverse'.
--
-- This is a strict variant on 'Data.Foldable.traverse_', which evaluates each
-- element of the structure even in a monad which would otherwise allow this to
-- be lazy.
straverse_ :: (Foldable t, Applicative f) => (a -> f b) -> t a -> f ()
straverse_ f = foldr c (pure ())
where
-- See Note [List fusion and continuations in 'c']
c !x !k = (*> k) $! f x
{-# INLINE c #-}

-- | 'sfor_' is 'straverse_' with its arguments flipped. For a version
-- that doesn't ignore the results see 'Data.Traversable.for'.
--
-- >>> sfor_ [1..4] print
-- 1
-- 2
-- 3
-- 4
sfor_ :: (Foldable t, Applicative f) => t a -> (a -> f b) -> f ()
{-# INLINE sfor_ #-}
sfor_ = flip straverse_
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,11 @@ import Hedgehog.Internal.Gen (integral_, runDiscardEffectT)
import Hedgehog.Internal.Tree (NodeT (NodeT), TreeT, nodeChildren, treeValue)

import Control.State.Transition.Extended (BaseM, Environment, IRC (IRC), PredicateFailure, STS, Signal,
State, TRC (TRC), applySTS)
State, TRC (TRC))
import qualified Control.State.Transition.Invalid.Trace as Invalid
import Control.State.Transition.Trace (Trace, TraceOrder (OldestFirst), closure,
extractValues, lastState, mkTrace, traceLength, traceSignals, _traceEnv)
extractValues, lastState, mkTrace, traceLength, traceSignals, _traceEnv
, applySTSTest)
import Hedgehog.Extra.Manual (Manual)
import qualified Hedgehog.Extra.Manual as Manual

Expand Down Expand Up @@ -212,7 +213,7 @@ traceSigGenWithProfile
-> Gen (Trace s)
traceSigGenWithProfile baseEnv aTraceLength profile gen = do
env <- envGen @s (traceLengthValue aTraceLength)
case interpretSTS @s baseEnv $ applySTS @s (IRC env) of
case interpretSTS @s baseEnv $ applySTSTest @s (IRC env) of
-- Hedgehog will give up if the generators fail to produce any valid
-- initial state, hence we don't have a risk of entering an infinite
-- recursion.
Expand Down Expand Up @@ -355,7 +356,7 @@ genTraceOfLength baseEnv aTraceLength profile env st0 aSigGen =
Nothing ->
loop (d - 1) sti acc
Just sig ->
case interpretSTS @s baseEnv $ applySTS @s (TRC(env, sti, sig)) of
case interpretSTS @s baseEnv $ applySTSTest @s (TRC(env, sti, sig)) of
Left _err -> loop (d - 1) sti acc
Right sti' -> loop (d - 1) sti' ((sti', sigTree) : acc)

Expand Down Expand Up @@ -405,7 +406,7 @@ invalidTrace baseEnv maxTraceLength failureProfile = do
let env = _traceEnv tr
st = lastState tr
iSig <- generateSignalWithFailureProportions @s failureProfile env st
let est' = interpretSTS @s baseEnv $ applySTS @s $ TRC (env, st, iSig)
let est' = interpretSTS @s baseEnv $ applySTSTest @s $ TRC (env, st, iSig)
pure $! Invalid.Trace
{ Invalid.validPrefix = tr
, Invalid.signal = iSig
Expand Down Expand Up @@ -655,7 +656,7 @@ onlyValidSignalsAreGeneratedForTrace baseEnv traceGen = property $ do
st' :: State s
st' = lastState tr
sig <- forAll (sigGen @s env st')
let result = interpretSTS @s baseEnv $ applySTS @s (TRC(env, st', sig))
let result = interpretSTS @s baseEnv $ applySTSTest @s (TRC(env, st', sig))
-- TODO: For some reason the result that led to the failure is not shown
-- (even without using tasty, and setting the condition to True === False)
footnoteShow st'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ module Control.State.Transition.Trace
, closure
-- * Miscellaneous utilities
, extractValues
, applySTSTest
)
where

Expand All @@ -51,15 +52,15 @@ import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Reader (MonadReader, ReaderT, ask, runReaderT)
import Data.Data (Data, Typeable, cast, gmapQ)
import Data.Functor ((<&>))
import Data.Maybe (catMaybes)
import Data.Sequence.Strict (StrictSeq ((:<|), Empty))
import qualified Data.Sequence.Strict as StrictSeq
import GHC.Generics (Generic)
import GHC.Stack (HasCallStack)
import Test.Tasty.HUnit (assertFailure, (@?=))

import Control.State.Transition.Extended (BaseM, Environment, PredicateFailure, STS, Signal, State,
TRC (TRC), applySTS)
import Control.State.Transition.Extended hiding (Assertion, trans)

-- Signal and resulting state.
--
Expand Down Expand Up @@ -374,7 +375,7 @@ closure env st0 sigs = mkTrace env st0 <$> loop st0 (reverse sigs) []
where
loop _ [] acc = pure acc
loop sti (sig : sigs') acc =
applySTS @s (TRC(env, sti, sig)) >>= \case
applySTSTest @s (TRC(env, sti, sig)) >>= \case
Left _ -> loop sti sigs' acc
Right sti' -> loop sti' sigs' ((sti', sig) : acc)

Expand Down Expand Up @@ -419,7 +420,7 @@ checkTrace
-> ReaderT (State s -> Signal s -> (Either [[PredicateFailure s]] (State s))) IO (State s)
-> IO ()
checkTrace interp env act =
void $ runReaderT act (\st sig -> interp $ applySTS (TRC(env, st, sig)))
void $ runReaderT act (\st sig -> interp $ applySTSTest (TRC(env, st, sig)))

-- | Extract all the values of a given type.
--
Expand Down Expand Up @@ -480,3 +481,20 @@ sourceSignalTargets trace = zipWith3 SourceSignalTarget states (tail states) sig
where
states = traceStates OldestFirst trace
signals = traceSignals OldestFirst trace

-- | Apply STS checking assertions.
applySTSTest ::
forall s m rtype.
(STS s, RuleTypeRep rtype, m ~ BaseM s) =>
RuleContext rtype s ->
m (Either [[PredicateFailure s]] (State s))
applySTSTest ctx =
applySTSOpts defaultOpts ctx <&> \case
(st, []) -> Right st
(_, pfs) -> Left pfs
where
defaultOpts =
ApplySTSOpts
{ asoAssertions = AssertionsAll,
asoValidation = ValidateAll
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ traceFrom traceEnv maxTraceLength traceGenEnv env st0 = do
loop 0 _ acc = pure $! acc
loop !d sti stSigs = do
sig <- sigGen @sts @traceGenEnv traceGenEnv env sti
case interpretSTS @sts @traceGenEnv traceEnv (STS.applySTS @sts (TRC(env, sti, sig))) of
case interpretSTS @sts @traceGenEnv traceEnv (Trace.applySTSTest @sts (TRC(env, sti, sig))) of
Left _predicateFailures ->
loop (d - 1) sti stSigs
Right sti' ->
Expand Down Expand Up @@ -135,7 +135,7 @@ traceFromInitState
traceFromInitState baseEnv maxTraceLength traceGenEnv genSt0 = do
env <- envGen @sts @traceGenEnv traceGenEnv
res <- fromMaybe (pure . interpretSTS @sts @traceGenEnv baseEnv
. STS.applySTS) genSt0 $ (IRC env)
. Trace.applySTSTest) genSt0 $ (IRC env)

case res of
Left pf -> error $ "Failed to apply the initial rule to the generated environment.\n"
Expand Down Expand Up @@ -265,7 +265,7 @@ onlyValidSignalsAreGeneratedFromInitState baseEnv maxTraceLength traceGenEnv gen
signalIsValid
where
signalIsValid signal =
case interpretSTS @sts @traceGenEnv baseEnv (STS.applySTS @sts (TRC (env, lastState, signal))) of
case interpretSTS @sts @traceGenEnv baseEnv (Trace.applySTSTest @sts (TRC (env, lastState, signal))) of
Left pf -> QuickCheck.counterexample (show (signal, pf)) False
Right _ -> QuickCheck.property True
env = Trace._traceEnv someTrace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ import Cardano.Crypto.DSIGN.Ed25519 as Ed25519
import qualified Cardano.Crypto.Hash as Hash
import qualified Cardano.Crypto.Signing as Byron
import qualified Cardano.Crypto.Wallet as WC
import Cardano.Prelude (panic)
import Cardano.Prelude
( AllowThunksIn (..),
ByteString,
Expand All @@ -57,6 +56,7 @@ import Cardano.Prelude
NoUnexpectedThunks,
Proxy (..),
Word8,
panic,
)
import qualified Data.ByteString.Lazy as LBS
import Data.Maybe (fromMaybe)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,11 @@ ledgerTransition = do

dpstate' <-
trans @(DELEGS crypto) $
TRC (DelegsEnv slot txIx pp tx account, dpstate, StrictSeq.getSeq $ _certs $ _body tx)
TRC
( DelegsEnv slot txIx pp tx account,
dpstate,
StrictSeq.getSeq $ _certs $ _body tx
)

let DPState dstate pstate = dpstate
genDelegs = _genDelegs dstate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,26 @@ module Shelley.Spec.Ledger.STS.Pool
)
where

import Cardano.Binary (FromCBOR (..), ToCBOR (..), decodeListLen, decodeWord, encodeListLen, matchSize)
import Cardano.Binary
( FromCBOR (..),
ToCBOR (..),
decodeListLen,
decodeWord,
encodeListLen,
matchSize,
)
import Cardano.Prelude (NoUnexpectedThunks (..))
import Control.Monad.Trans.Reader (asks)
import Control.State.Transition (STS (..), TRC (..), TransitionRule, failBecause, judgmentContext, liftSTS, (?!))
import Control.State.Transition
( Assertion (..),
STS (..),
TRC (..),
TransitionRule,
failBecause,
judgmentContext,
liftSTS,
(?!),
)
import Data.Kind (Type)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
Expand All @@ -25,13 +41,18 @@ import Data.Word (Word64, Word8)
import GHC.Generics (Generic)
import Shelley.Spec.Ledger.BaseTypes (Globals (..), ShelleyBase, invalidKey)
import Shelley.Spec.Ledger.Coin (Coin)
import Shelley.Spec.Ledger.Core (addpair, haskey, removekey)
import Shelley.Spec.Ledger.Core (addpair, dom, haskey, removekey)
import Shelley.Spec.Ledger.Crypto (Crypto)
import Shelley.Spec.Ledger.Keys (KeyHash (..), KeyRole (..))
import Shelley.Spec.Ledger.LedgerState (PState (..), emptyPState)
import Shelley.Spec.Ledger.PParams (PParams, PParams' (..))
import Shelley.Spec.Ledger.Slot (EpochNo (..), SlotNo, epochInfoEpoch)
import Shelley.Spec.Ledger.TxData (DCert (..), PoolCert (..), PoolParams (..), StakePools (..))
import Shelley.Spec.Ledger.TxData
( DCert (..),
PoolCert (..),
PoolParams (..),
StakePools (..),
)

data POOL (crypto :: Type)

Expand Down Expand Up @@ -66,6 +87,14 @@ instance Typeable crypto => STS (POOL crypto) where

transitionRules = [poolDelegationTransition]

assertions =
[ PreCondition
"_stPools and _pParams must have the same domain"
( \(TRC (_, st, _)) ->
dom (unStakePools $ _stPools st) == dom (_pParams st)
)
]

instance NoUnexpectedThunks (PredicateFailure (POOL crypto))

instance
Expand Down
Loading