Skip to content

Commit

Permalink
Peephole optimization of Cairo assembly (#2858)
Browse files Browse the repository at this point in the history
* Closes #2703 
* Adds [peephole
optimization](https://en.wikipedia.org/wiki/Peephole_optimization) of
Cairo assembly.
* Adds a transformation framework for the CASM IR.
* Adds `--transforms`, `--run` and `--no-print` options to the `dev casm
read` command.
  • Loading branch information
lukaszcz authored Jun 27, 2024
1 parent 4dcbb00 commit 802d82f
Show file tree
Hide file tree
Showing 18 changed files with 327 additions and 38 deletions.
29 changes: 27 additions & 2 deletions app/Commands/Dev/Casm/Read.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ module Commands.Dev.Casm.Read where

import Commands.Base
import Commands.Dev.Casm.Read.Options
import Juvix.Compiler.Casm.Pretty qualified as Casm
import Juvix.Compiler.Casm.Data.InputInfo qualified as Casm
import Juvix.Compiler.Casm.Extra.LabelInfo qualified as Casm
import Juvix.Compiler.Casm.Interpreter qualified as Casm
import Juvix.Compiler.Casm.Pretty qualified as Casm.Pretty
import Juvix.Compiler.Casm.Transformation qualified as Casm
import Juvix.Compiler.Casm.Translation.FromSource qualified as Casm
import Juvix.Compiler.Casm.Validate qualified as Casm

Expand All @@ -15,7 +19,28 @@ runCommand opts = do
Right (labi, code) ->
case Casm.validate labi code of
Left err -> exitJuvixError (JuvixError err)
Right () -> renderStdOut (Casm.ppProgram code)
Right () -> do
r <-
runError @JuvixError
. runReader Casm.defaultOptions
$ (Casm.applyTransformations (project opts ^. casmReadTransformations) code)
case r of
Left err -> exitJuvixError (JuvixError err)
Right code' -> do
unless (project opts ^. casmReadNoPrint) $
renderStdOut (Casm.Pretty.ppProgram code')
doRun code'
where
file :: AppPath File
file = opts ^. casmReadInputFile

doRun :: Casm.Code -> Sem r ()
doRun code'
| project opts ^. casmReadRun = do
putStrLn "--------------------------------"
putStrLn "| Run |"
putStrLn "--------------------------------"
let labi = Casm.computeLabelInfo code'
inputInfo = Casm.InputInfo mempty
print (Casm.runCode inputInfo labi code')
| otherwise = return ()
11 changes: 9 additions & 2 deletions app/Commands/Dev/Casm/Read/Options.hs
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
module Commands.Dev.Casm.Read.Options where

import CommonOptions
import Juvix.Compiler.Casm.Data.TransformationId

newtype CasmReadOptions = CasmReadOptions
{ _casmReadInputFile :: AppPath File
data CasmReadOptions = CasmReadOptions
{ _casmReadTransformations :: [TransformationId],
_casmReadRun :: Bool,
_casmReadNoPrint :: Bool,
_casmReadInputFile :: AppPath File
}
deriving stock (Data)

makeLenses ''CasmReadOptions

parseCasmReadOptions :: Parser CasmReadOptions
parseCasmReadOptions = do
_casmReadNoPrint <- optReadNoPrint
_casmReadRun <- optReadRun
_casmReadTransformations <- optCasmTransformationIds
_casmReadInputFile <- parseInputFile FileExtCasm
pure CasmReadOptions {..}
12 changes: 2 additions & 10 deletions app/Commands/Dev/Reg/Read/Options.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,8 @@ makeLenses ''RegReadOptions

parseRegReadOptions :: Parser RegReadOptions
parseRegReadOptions = do
_regReadNoPrint <-
switch
( long "no-print"
<> help "Do not print the transformed code"
)
_regReadRun <-
switch
( long "run"
<> help "Run the code after the transformation"
)
_regReadNoPrint <- optReadNoPrint
_regReadRun <- optReadRun
_regReadTransformations <- optRegTransformationIds
_regReadInputFile <- parseInputFile FileExtJuvixReg
pure RegReadOptions {..}
18 changes: 18 additions & 0 deletions app/CommonOptions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ where
import Control.Exception qualified as GHC
import Data.List.NonEmpty qualified as NonEmpty
import GHC.Conc
import Juvix.Compiler.Casm.Data.TransformationId.Parser qualified as Casm
import Juvix.Compiler.Concrete.Translation.ImportScanner
import Juvix.Compiler.Core.Data.TransformationId.Parser qualified as Core
import Juvix.Compiler.Pipeline.EntryPoint
Expand Down Expand Up @@ -282,6 +283,20 @@ optNoDisambiguate =
<> help "Don't disambiguate the names of bound variables"
)

optReadRun :: Parser Bool
optReadRun =
switch
( long "run"
<> help "Run the code after the transformation"
)

optReadNoPrint :: Parser Bool
optReadNoPrint =
switch
( long "no-print"
<> help "Do not print the transformed code"
)

optTransformationIds :: forall a. (Text -> Either Text [a]) -> (String -> [String]) -> Parser [a]
optTransformationIds parseIds completions =
option
Expand Down Expand Up @@ -317,6 +332,9 @@ optTreeTransformationIds = optTransformationIds Tree.parseTransformations Tree.c
optRegTransformationIds :: Parser [Reg.TransformationId]
optRegTransformationIds = optTransformationIds Reg.parseTransformations Reg.completionsString

optCasmTransformationIds :: Parser [Casm.TransformationId]
optCasmTransformationIds = optTransformationIds Casm.parseTransformations Casm.completionsString

class EntryPointOptions a where
applyOptions :: a -> EntryPoint -> EntryPoint

Expand Down
3 changes: 3 additions & 0 deletions runtime/casm/stdlib.casm
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ juvix_ec_op:
-- [fp - 3]: closure
-- [fp - 4]: n = the number of arguments to extend with
-- [fp - 4 - k]: argument n - k - 1 (reverse order!) (k is 0-based)
-- On return:
-- [ap - 1]: new closure
-- This procedure doesn't accept or return the builtins pointer.
juvix_extend_closure:
-- copy stored args reversing them;
-- to copy the stored args to the new closure
Expand Down
34 changes: 34 additions & 0 deletions src/Juvix/Compiler/Casm/Data/TransformationId.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module Juvix.Compiler.Casm.Data.TransformationId where

import Juvix.Compiler.Casm.Data.TransformationId.Strings
import Juvix.Compiler.Core.Data.TransformationId.Base
import Juvix.Prelude

data TransformationId
= IdentityTrans
| Peephole
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
= PipelineCairo
deriving stock (Data, Bounded, Enum)

type TransformationLikeId = TransformationLikeId' TransformationId PipelineId

toCairoTransformations :: [TransformationId]
toCairoTransformations = [Peephole]

instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
transformationText = \case
IdentityTrans -> strIdentity
Peephole -> strPeephole

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
pipelineText = \case
PipelineCairo -> strCairoPipeline

pipeline :: PipelineId -> [TransformationId]
pipeline = \case
PipelineCairo -> toCairoTransformations
14 changes: 14 additions & 0 deletions src/Juvix/Compiler/Casm/Data/TransformationId/Parser.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module Juvix.Compiler.Casm.Data.TransformationId.Parser (parseTransformations, TransformationId (..), completions, completionsString) where

import Juvix.Compiler.Casm.Data.TransformationId
import Juvix.Compiler.Core.Data.TransformationId.Parser.Base
import Juvix.Prelude

parseTransformations :: Text -> Either Text [TransformationId]
parseTransformations = parseTransformations' @TransformationId @PipelineId

completionsString :: String -> [String]
completionsString = completionsString' @TransformationId @PipelineId

completions :: Text -> [Text]
completions = completions' @TransformationId @PipelineId
12 changes: 12 additions & 0 deletions src/Juvix/Compiler/Casm/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module Juvix.Compiler.Casm.Data.TransformationId.Strings where

import Juvix.Prelude

strCairoPipeline :: Text
strCairoPipeline = "pipeline-cairo"

strIdentity :: Text
strIdentity = "identity"

strPeephole :: Text
strPeephole = "peephole"
17 changes: 17 additions & 0 deletions src/Juvix/Compiler/Casm/Pipeline.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module Juvix.Compiler.Casm.Pipeline
( module Juvix.Compiler.Casm.Pipeline,
Options,
Code,
)
where

import Juvix.Compiler.Casm.Transformation
import Juvix.Compiler.Pipeline.EntryPoint (EntryPoint)

-- | Perform transformations on CASM necessary before the translation to Cairo
-- bytecode
toCairo' :: Code -> Sem r Code
toCairo' = applyTransformations toCairoTransformations

toCairo :: (Member (Reader EntryPoint) r) => Code -> Sem r Code
toCairo = mapReader fromEntryPoint . toCairo'
18 changes: 18 additions & 0 deletions src/Juvix/Compiler/Casm/Transformation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module Juvix.Compiler.Casm.Transformation
( module Juvix.Compiler.Casm.Transformation.Base,
module Juvix.Compiler.Casm.Transformation,
module Juvix.Compiler.Casm.Data.TransformationId,
)
where

import Juvix.Compiler.Casm.Data.TransformationId
import Juvix.Compiler.Casm.Transformation.Base
import Juvix.Compiler.Casm.Transformation.Optimize.Peephole

applyTransformations :: forall r. [TransformationId] -> Code -> Sem r Code
applyTransformations ts tbl = foldM (flip appTrans) tbl ts
where
appTrans :: TransformationId -> Code -> Sem r Code
appTrans = \case
IdentityTrans -> return
Peephole -> return . peephole
17 changes: 17 additions & 0 deletions src/Juvix/Compiler/Casm/Transformation/Base.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module Juvix.Compiler.Casm.Transformation.Base
( module Juvix.Compiler.Casm.Transformation.Base,
module Juvix.Compiler.Casm.Language,
module Juvix.Compiler.Tree.Options,
)
where

import Juvix.Compiler.Casm.Language
import Juvix.Compiler.Tree.Options

mapT :: ([Instruction] -> [Instruction]) -> [Instruction] -> [Instruction]
mapT f = go
where
go :: [Instruction] -> [Instruction]
go = \case
i : is -> f (i : go is)
[] -> f []
78 changes: 78 additions & 0 deletions src/Juvix/Compiler/Casm/Transformation/Optimize/Peephole.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
module Juvix.Compiler.Casm.Transformation.Optimize.Peephole where

import Juvix.Compiler.Casm.Extra.Base
import Juvix.Compiler.Casm.Language
import Juvix.Compiler.Casm.Transformation.Base

peephole :: [Instruction] -> [Instruction]
peephole = mapT go
where
go :: [Instruction] -> [Instruction]
go = \case
Nop : is -> is
Jump InstrJump {..} : lab@(Label LabelRef {..}) : is
| not _instrJumpIncAp,
Val (Lab (LabelRef sym _)) <- _instrJumpTarget,
sym == _labelRefSymbol ->
lab : is
Call InstrCall {..} : Return : Assign a1 : Return : is
| _instrCallRel,
Imm 3 <- _instrCallTarget,
Just k1 <- getAssignApFp a1 ->
fixAssignAp $
mkAssignAp (Val (Ref (MemRef Ap k1)))
: Return
: is
Call InstrCall {..} : Return : Assign a1 : Assign a2 : Return : is
| _instrCallRel,
Imm 3 <- _instrCallTarget,
Just k1 <- getAssignApFp a1,
Just k2 <- getAssignApFp a2 ->
fixAssignAp $
mkAssignAp (Val (Ref (MemRef Ap k1)))
: mkAssignAp (Val (Ref (MemRef Ap (k2 - 1))))
: Return
: is
Call InstrCall {..} : Return : Jump InstrJump {..} : is
| _instrCallRel,
Imm 3 <- _instrCallTarget,
Val tgt@(Lab {}) <- _instrJumpTarget,
not _instrJumpIncAp ->
let call =
InstrCall
{ _instrCallTarget = tgt,
_instrCallRel = _instrJumpRel
}
in Call call : Return : is
is -> is

fixAssignAp :: [Instruction] -> [Instruction]
fixAssignAp = \case
Assign a : Return : is
| Just (-1) <- getAssignAp Ap a ->
Return : is
Assign a1 : Assign a2 : Return : is
| Just (-2) <- getAssignAp Ap a1,
Just (-2) <- getAssignAp Ap a2 ->
Return : is
Assign a1 : Assign a2 : Return : is
| Just (-1) <- getAssignAp Ap a1,
Just (-3) <- getAssignAp Ap a2 ->
mkAssignAp (Val (Ref (MemRef Ap (-2)))) : Return : is
is -> is

getAssignAp :: Reg -> InstrAssign -> Maybe Offset
getAssignAp reg InstrAssign {..}
| MemRef Ap 0 <- _instrAssignResult,
Val (Ref (MemRef r k)) <- _instrAssignValue,
r == reg,
_instrAssignIncAp =
Just k
| otherwise =
Nothing

getAssignApFp :: InstrAssign -> Maybe Offset
getAssignApFp instr = case getAssignAp Fp instr of
Just k
| k <= -3 -> Just (k + 2)
_ -> Nothing
27 changes: 16 additions & 11 deletions src/Juvix/Compiler/Casm/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,18 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI

-- To ensure that memory is accessed sequentially at all times, we divide
-- instructions into basic blocks. Within each basic block, the `ap` offset
-- is known at each instruction, which allows to statically associate `fp`
-- offsets to variables while still generating only sequential assignments
-- to `[ap]` with increasing `ap`. When the `ap` offset can no longer be
-- statically determined for new variables (e.g. due to an intervening
-- recursive call), we switch to the next basic block by "calling" it with
-- the `call` instruction (see `goCallBlock`). The arguments of the basic
-- block call are the variables live at the beginning of the block. Note
-- that the `fp` offsets of "old" variables are still statically determined
-- even after the current `ap` offset becomes unknown -- the arbitrary
-- increase of `ap` does not influence the previous variable associations.
-- (i.e. how much `ap` increased since the start of the basic block) is
-- known at each instruction, which allows to statically associate `fp`
-- offsets (i.e. offsets relative to `fp`) to variables while still
-- generating only sequential assignments to `[ap]` with increasing `ap`.
-- When the `ap` offset can no longer be statically determined for new
-- variables (e.g. due to an intervening recursive call), we switch to the
-- next basic block by "calling" it with the `call` instruction (see
-- `goCallBlock`). The arguments of the basic block call are the variables
-- live at the beginning of the block. Note that the `fp` offsets of "old"
-- variables are still statically determined even after the current `ap`
-- offset becomes unknown -- the arbitrary increase of `ap` does not
-- influence the previous variable associations.
goBlock :: forall r. (Members '[LabelInfoBuilder, CasmBuilder, Output Instruction] r) => StdlibBuiltins -> LabelRef -> HashSet Reg.VarRef -> Maybe Reg.VarRef -> Reg.Block -> Sem r ()
goBlock blts failLab liveVars0 mout Reg.Block {..} = do
mapM_ goInstr _blockBody
Expand Down Expand Up @@ -645,7 +647,10 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
ap0 <- getAP
vars <- getVars
bltOff <- getBuiltinOffset
mapM_ (goCaseBranch ap0 vars bltOff symMap labEnd) _instrCaseBranches
-- reversing `_instrCaseBranches` typically results in better
-- opportunities for peephole optimization (the last jump to branch
-- may be removed by the peephole optimizer)
mapM_ (goCaseBranch ap0 vars bltOff symMap labEnd) (reverse _instrCaseBranches)
mapM_ (goDefaultLabel symMap) defaultTags
whenJust _instrCaseDefault $
goLocalBlock ap0 vars bltOff liveVars _instrCaseOutVar
Expand Down
Loading

0 comments on commit 802d82f

Please sign in to comment.