Skip to content

Commit

Permalink
Constant propagation in JuvixReg (#2833)
Browse files Browse the repository at this point in the history
* Closes #2702 
* For this to give any improvement, we need to run dead code elimination
afterwards (#2827).

Depends on:
* #2828
  • Loading branch information
lukaszcz authored Jun 21, 2024
1 parent af758cc commit 1410b63
Show file tree
Hide file tree
Showing 18 changed files with 268 additions and 28 deletions.
21 changes: 15 additions & 6 deletions src/Juvix/Compiler/Casm/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -409,12 +409,21 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
Left err -> error err
Right c ->
goAssignVar _instrBinopResult (Val $ Imm $ mkConst c)
_ ->
goBinop
x
{ Reg._instrBinopArg1 = _instrBinopArg2,
Reg._instrBinopArg2 = _instrBinopArg1
}
_
| Reg.isCommutative _instrBinopOpcode ->
goBinop
x
{ Reg._instrBinopArg1 = _instrBinopArg2,
Reg._instrBinopArg2 = _instrBinopArg1
}
| otherwise -> do
goAssignAp (Val $ Imm $ mkConst c1)
v2 <- goValue _instrBinopArg2
case _instrBinopArg2 of
Reg.CRef {} -> do
goBinop' _instrBinopOpcode _instrBinopResult (MemRef Ap (-2)) v2
_ -> do
goBinop' _instrBinopOpcode _instrBinopResult (MemRef Ap (-1)) v2
Reg.CRef ctr1 -> do
v1 <- mkLoad ctr1
goAssignAp v1
Expand Down
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Reg/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ data TransformationId
| SSA
| InitBranchVars
| CopyPropagation
| ConstantPropagation
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand All @@ -27,7 +28,7 @@ toRustTransformations :: [TransformationId]
toRustTransformations = [Cleanup]

toCasmTransformations :: [TransformationId]
toCasmTransformations = [Cleanup, CopyPropagation, SSA]
toCasmTransformations = [Cleanup, CopyPropagation, ConstantPropagation, SSA]

instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
Expand All @@ -37,6 +38,7 @@ instance TransformationId' TransformationId where
SSA -> strSSA
InitBranchVars -> strInitBranchVars
CopyPropagation -> strCopyPropagation
ConstantPropagation -> strConstantPropagation

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ strInitBranchVars = "init-branch-vars"

strCopyPropagation :: Text
strCopyPropagation = "copy-propagation"

strConstantPropagation :: Text
strConstantPropagation = "constant-propagation"
24 changes: 16 additions & 8 deletions src/Juvix/Compiler/Reg/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ setResultVar instr vref = case instr of
CallClosures x -> CallClosures $ set instrCallClosuresResult vref x
_ -> impossible

overValueRefs :: (VarRef -> VarRef) -> Instruction -> Instruction
overValueRefs f = \case
overValueRefs' :: (VarRef -> Value) -> Instruction -> Instruction
overValueRefs' f = \case
Binop x -> Binop $ goBinop x
Unop x -> Unop $ goUnop x
Cairo x -> Cairo $ goCairo x
Expand All @@ -51,14 +51,19 @@ overValueRefs f = \case
Nop -> Nop
Block x -> Block $ goBlock x
where
fromVarRef :: Value -> VarRef
fromVarRef = \case
VRef r -> r
_ -> impossible

goConstrField :: ConstrField -> ConstrField
goConstrField = over constrFieldRef f
goConstrField = over constrFieldRef (fromVarRef . f)

goValue :: Value -> Value
goValue = \case
ValConst c -> ValConst c
CRef x -> CRef $ goConstrField x
VRef x -> VRef $ f x
VRef x -> f x

goBinop :: InstrBinop -> InstrBinop
goBinop InstrBinop {..} =
Expand Down Expand Up @@ -86,15 +91,15 @@ overValueRefs f = \case
goExtendClosure :: InstrExtendClosure -> InstrExtendClosure
goExtendClosure InstrExtendClosure {..} =
InstrExtendClosure
{ _instrExtendClosureValue = f _instrExtendClosureValue,
{ _instrExtendClosureValue = fromVarRef (f _instrExtendClosureValue),
_instrExtendClosureArgs = map goValue _instrExtendClosureArgs,
..
}

goCallType :: CallType -> CallType
goCallType = \case
CallFun sym -> CallFun sym
CallClosure cl -> CallClosure (f cl)
CallClosure cl -> CallClosure (fromVarRef (f cl))

goCall :: InstrCall -> InstrCall
goCall InstrCall {..} =
Expand All @@ -108,7 +113,7 @@ overValueRefs f = \case
goCallClosures InstrCallClosures {..} =
InstrCallClosures
{ _instrCallClosuresArgs = map goValue _instrCallClosuresArgs,
_instrCallClosuresValue = f _instrCallClosuresValue,
_instrCallClosuresValue = fromVarRef (f _instrCallClosuresValue),
..
}

Expand All @@ -123,7 +128,7 @@ overValueRefs f = \case
goTailCallClosures :: InstrTailCallClosures -> InstrTailCallClosures
goTailCallClosures InstrTailCallClosures {..} =
InstrTailCallClosures
{ _instrTailCallClosuresValue = f _instrTailCallClosuresValue,
{ _instrTailCallClosuresValue = fromVarRef (f _instrTailCallClosuresValue),
_instrTailCallClosuresArgs = map goValue _instrTailCallClosuresArgs,
..
}
Expand All @@ -149,6 +154,9 @@ overValueRefs f = \case
goBlock :: InstrBlock -> InstrBlock
goBlock x = x

overValueRefs :: (VarRef -> VarRef) -> Instruction -> Instruction
overValueRefs f = overValueRefs' (VRef . f)

updateLiveVars' :: (VarRef -> Maybe VarRef) -> Instruction -> Instruction
updateLiveVars' f = \case
Prealloc x -> Prealloc $ over instrPreallocLiveVars (mapMaybe f) x
Expand Down
5 changes: 4 additions & 1 deletion src/Juvix/Compiler/Reg/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import Data.Functor.Identity
import Juvix.Compiler.Reg.Language

data ForwardRecursorSig m c = ForwardRecursorSig
{ _forwardFun :: Instruction -> c -> m (c, Instruction),
{ -- `_forwardFun` is always called first
_forwardFun :: Instruction -> c -> m (c, Instruction),
-- `_forwardCombine` is called if the result of applying `_forwardFun` is
-- `Branch` or `Case`
_forwardCombine :: Instruction -> NonEmpty c -> (c, Instruction)
}

Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Reg/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ where
import Juvix.Compiler.Reg.Data.TransformationId
import Juvix.Compiler.Reg.Transformation.Base
import Juvix.Compiler.Reg.Transformation.Cleanup
import Juvix.Compiler.Reg.Transformation.ConstantPropagation (constantPropagate)
import Juvix.Compiler.Reg.Transformation.CopyPropagation
import Juvix.Compiler.Reg.Transformation.IdentityTrans
import Juvix.Compiler.Reg.Transformation.InitBranchVars
Expand All @@ -23,3 +24,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
SSA -> return . computeSSA
InitBranchVars -> return . initBranchVars
CopyPropagation -> return . copyPropagate
ConstantPropagation -> return . constantPropagate
65 changes: 65 additions & 0 deletions src/Juvix/Compiler/Reg/Transformation/ConstantPropagation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
module Juvix.Compiler.Reg.Transformation.ConstantPropagation where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Reg.Extra
import Juvix.Compiler.Reg.Transformation.Base
import Juvix.Compiler.Tree.Evaluator.Builtins

type VarMap = HashMap VarRef Constant

constantPropagateFunction :: Code -> Code
constantPropagateFunction =
snd
. runIdentity
. recurseF
ForwardRecursorSig
{ _forwardFun = \i acc -> return (go i acc),
_forwardCombine = combine
}
mempty
where
go :: Instruction -> VarMap -> (VarMap, Instruction)
go instr mpv = case instr' of
Assign InstrAssign {..}
| ValConst c <- _instrAssignValue ->
(HashMap.insert _instrAssignResult c mpv', instr')
Binop InstrBinop {..}
| ValConst c1 <- _instrBinopArg1,
ValConst c2 <- _instrBinopArg2 ->
case evalBinop' _instrBinopOpcode c1 c2 of
Left _ -> (mpv', instr')
Right c ->
( HashMap.insert _instrBinopResult c mpv',
Assign
InstrAssign
{ _instrAssignResult = _instrBinopResult,
_instrAssignValue = ValConst c
}
)
_ ->
(mpv', instr')
where
instr' = overValueRefs' (adjustVarRef mpv) instr
mpv' = maybe mpv (`HashMap.delete` mpv) (getResultVar instr)

adjustVarRef :: VarMap -> VarRef -> Value
adjustVarRef mpv vref@VarRef {..} = case _varRefGroup of
VarGroupArgs -> VRef vref
VarGroupLocal -> maybe (VRef vref) ValConst (HashMap.lookup vref mpv)

combine :: Instruction -> NonEmpty VarMap -> (VarMap, Instruction)
combine instr mpvs = case instr of
Branch InstrBranch {..}
| ValConst (ConstBool True) <- _instrBranchValue ->
(mpv1, Block $ InstrBlock _instrBranchTrue)
| ValConst (ConstBool False) <- _instrBranchValue ->
(mpv2, Block $ InstrBlock _instrBranchFalse)
where
(mpv1, mpv2) = case mpvs of
mpv1' :| [mpv2'] -> (mpv1', mpv2')
_ -> impossible
_ ->
(combineMaps mpvs, instr)

constantPropagate :: InfoTable -> InfoTable
constantPropagate = mapT (const constantPropagateFunction)
9 changes: 1 addition & 8 deletions src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module Juvix.Compiler.Reg.Transformation.CopyPropagation where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Reg.Extra
import Juvix.Compiler.Reg.Transformation.Base

Expand Down Expand Up @@ -40,13 +39,7 @@ copyPropagateFunction =
combine :: Instruction -> NonEmpty VarMap -> (VarMap, Instruction)
combine instr mpvs = (mpv, instr')
where
mpv' :| mpvs' = fmap HashMap.toList mpvs
mpv =
HashMap.fromList
. HashSet.toList
. foldr (HashSet.intersection . HashSet.fromList) (HashSet.fromList mpv')
$ mpvs'

mpv = combineMaps mpvs
instr' = case instr of
Branch x -> Branch $ over instrBranchOutVar (fmap (adjustVarRef mpv)) x
Case x -> Case $ over instrCaseOutVar (fmap (adjustVarRef mpv)) x
Expand Down
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Tree/Language/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ data Constant
| ConstField FField
| ConstUnit
| ConstVoid
deriving stock (Eq)
deriving stock (Eq, Generic)

instance (Hashable Constant)

-- | MemRefs are references to values stored in memory.
data MemRef
Expand Down
16 changes: 16 additions & 0 deletions src/Juvix/Compiler/Tree/Language/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@ data BinaryOp
| OpStrConcat
deriving stock (Eq)

isCommutative :: BinaryOp -> Bool
isCommutative = \case
OpIntAdd -> True
OpIntSub -> False
OpIntMul -> True
OpIntDiv -> False
OpIntMod -> False
OpIntLt -> False
OpIntLe -> False
OpFieldAdd -> True
OpFieldSub -> False
OpFieldMul -> True
OpFieldDiv -> False
OpEq -> True
OpStrConcat -> False

data UnaryOp
= -- | Convert the argument to a string. JV* opcode: `show`.
OpShow
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Data/Field.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ instance Pretty FField where
instance Show FField where
show f = show (fieldToInteger f)

instance Hashable FField where
hashWithSalt salt f = hashWithSalt salt (fieldToInteger f)

fieldAdd :: FField -> FField -> FField
fieldAdd
(FField ((n1 :: Sing (p :: Natural)) :&: (f1 :: PrimeField p)))
Expand Down
9 changes: 9 additions & 0 deletions src/Juvix/Prelude/Base/Foundation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,15 @@ tableNestedInsert ::
HashMap k1 (HashMap k2 a)
tableNestedInsert k1 k2 = tableInsert (HashMap.singleton k2) (HashMap.insert k2) k1

combineMaps :: (Hashable k, Hashable v) => NonEmpty (HashMap k v) -> HashMap k v
combineMaps mps =
HashMap.fromList
. HashSet.toList
. foldr (HashSet.intersection . HashSet.fromList) (HashSet.fromList mpv')
$ mpvs'
where
mpv' :| mpvs' = fmap HashMap.toList mps

--------------------------------------------------------------------------------
-- List
--------------------------------------------------------------------------------
Expand Down
8 changes: 7 additions & 1 deletion test/Casm/Reg/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -222,5 +222,11 @@ tests =
$(mkRelDir ".")
$(mkRelFile "test040.jvr")
$(mkRelFile "out/test040.out")
(Just $(mkRelFile "in/test040.json"))
(Just $(mkRelFile "in/test040.json")),
PosTest
"Test043: Copy & constant propagation"
$(mkRelDir ".")
$(mkRelFile "test043.jvr")
$(mkRelFile "out/test043.out")
Nothing
]
4 changes: 3 additions & 1 deletion test/Reg/Transformation.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Reg.Transformation where

import Base
import Reg.Transformation.ConstantPropagation qualified as ConstantPropagation
import Reg.Transformation.CopyPropagation qualified as CopyPropagation
import Reg.Transformation.IdentityTrans qualified as IdentityTrans
import Reg.Transformation.InitBranchVars qualified as InitBranchVars
Expand All @@ -13,5 +14,6 @@ allTests =
[ IdentityTrans.allTests,
SSA.allTests,
InitBranchVars.allTests,
CopyPropagation.allTests
CopyPropagation.allTests,
ConstantPropagation.allTests
]
21 changes: 21 additions & 0 deletions test/Reg/Transformation/ConstantPropagation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module Reg.Transformation.ConstantPropagation where

import Base
import Juvix.Compiler.Reg.Transformation
import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base

allTests :: TestTree
allTests = testGroup "JuvixReg Constant Propagation" (map liftTest Parse.tests)

pipe :: [TransformationId]
pipe = [ConstantPropagation]

liftTest :: Parse.PosTest -> TestTree
liftTest _testRun =
fromTest
Test
{ _testTransformations = pipe,
_testAssertion = const (return ()),
_testRun
}
2 changes: 1 addition & 1 deletion test/Reg/Transformation/CopyPropagation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base

allTests :: TestTree
allTests = testGroup "Copy Propagation" (map liftTest Parse.tests)
allTests = testGroup "JuvixReg Copy Propagation" (map liftTest Parse.tests)

pipe :: [TransformationId]
pipe = [CopyPropagation]
Expand Down
1 change: 1 addition & 0 deletions tests/Casm/Reg/positive/out/test043.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
79
Loading

0 comments on commit 1410b63

Please sign in to comment.