Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Cairo VM input hints #2709

Merged
merged 11 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ env:
SKIP: ormolu,format-juvix-files,typecheck-juvix-examples
VAMPIRREPO: anoma/vamp-ir
VAMPIRVERSION: v0.1.3
CAIRO_VM_VERSION: 42e04161de82d7e5381258def4b65087c8944660
CAIRO_VM_VERSION: ae06ba04f3b6864546b6baeeebf1b0735ddccb5d

jobs:
pre-commit:
Expand Down Expand Up @@ -127,7 +127,7 @@ jobs:
id: cache-cairo-vm
uses: actions/cache@v4
with:
path: ${{ env.HOME }}/.local/bin/cairo-vm-cli
path: ${{ env.HOME }}/.local/bin/juvix-cairo-vm
key: ${{ runner.os }}-cairo-vm-${{ env.CAIRO_VM_VERSION }}

- name: Install Rust toolchain
Expand All @@ -138,17 +138,17 @@ jobs:
uses: actions/checkout@v4
if: steps.cache-cairo-vm.outputs.cache-hit != 'true'
with:
repository: lambdaclass/cairo-vm
path: cairo-vm
repository: anoma/juvix-cairo-vm
path: juvix-cairo-vm
ref: ${{ env.CAIRO_VM_VERSION }}

- name: Install Cairo VM
if: steps.cache-cairo-vm.outputs.cache-hit != 'true'
shell: bash
run: |
cd cairo-vm
cd juvix-cairo-vm
cargo build --release
cp target/release/cairo-vm-cli $HOME/.local/bin/cairo-vm-cli
cp target/release/juvix-cairo-vm $HOME/.local/bin/juvix-cairo-vm

- name: Install run_cairo_vm.sh
shell: bash
Expand Down Expand Up @@ -323,7 +323,7 @@ jobs:
id: cache-cairo-vm
uses: actions/cache@v4
with:
path: ${{ env.HOME }}/.local/bin/cairo-vm-cli
path: ${{ env.HOME }}/.local/bin/juvix-cairo-vm
key: ${{ runner.os }}-cairo-vm-${{ env.CAIRO_VM_VERSION }}

- name: Install Rust toolchain
Expand All @@ -334,24 +334,22 @@ jobs:
uses: actions/checkout@v4
if: steps.cache-cairo-vm.outputs.cache-hit != 'true'
with:
repository: lambdaclass/cairo-vm
path: cairo-vm
repository: anoma/juvix-cairo-vm
path: juvix-cairo-vm
ref: ${{ env.CAIRO_VM_VERSION }}

- name: Install Cairo VM
if: steps.cache-cairo-vm.outputs.cache-hit != 'true'
shell: bash
run: |
cd cairo-vm
cd juvix-cairo-vm
cargo build --release
cp -a target/release/cairo-vm-cli $HOME/.local/bin/cairo-vm-cli
chmod a+x $HOME/.local/bin/cairo-vm-cli
cp -a target/release/juvix-cairo-vm $HOME/.local/bin/juvix-cairo-vm

- name: Install run_cairo_vm.sh
shell: bash
run: |
cp -a main/scripts/run_cairo_vm.sh $HOME/.local/bin/run_cairo_vm.sh
chmod a+x $HOME/.local/bin/run_cairo_vm.sh

- name: Make runtime
run: |
Expand Down
5 changes: 4 additions & 1 deletion app/Commands/Dev/Casm/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@ module Commands.Dev.Casm.Run where

import Commands.Base
import Commands.Dev.Casm.Run.Options
import Juvix.Compiler.Casm.Extra.InputInfo qualified as Casm
import Juvix.Compiler.Casm.Interpreter qualified as Casm
import Juvix.Compiler.Casm.Translation.FromSource qualified as Casm
import Juvix.Compiler.Casm.Validate qualified as Casm

runCommand :: forall r. (Members '[EmbedIO, App] r) => CasmRunOptions -> Sem r ()
runCommand opts = do
afile :: Path Abs File <- fromAppPathFile file
dfile :: Maybe (Path Abs File) <- maybe (return Nothing) (fromAppPathFile >=> return . Just) (opts ^. casmRunDataFile)
inputInfo <- liftIO (Casm.readInputInfo dfile)
s <- readFile afile
case Casm.runParser afile s of
Left err -> exitJuvixError (JuvixError err)
Right (labi, code) ->
case Casm.validate labi code of
Left err -> exitJuvixError (JuvixError err)
Right () -> print (Casm.runCode labi code)
Right () -> print (Casm.runCode inputInfo labi code)
where
file :: AppPath File
file = opts ^. casmRunInputFile
6 changes: 4 additions & 2 deletions app/Commands/Dev/Casm/Run/Options.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ module Commands.Dev.Casm.Run.Options where

import CommonOptions

newtype CasmRunOptions = CasmRunOptions
{ _casmRunInputFile :: AppPath File
data CasmRunOptions = CasmRunOptions
{ _casmRunInputFile :: AppPath File,
_casmRunDataFile :: Maybe (AppPath File)
}
deriving stock (Data)

Expand All @@ -12,4 +13,5 @@ makeLenses ''CasmRunOptions
parseCasmRunOptions :: Parser CasmRunOptions
parseCasmRunOptions = do
_casmRunInputFile <- parseInputFile FileExtCasm
_casmRunDataFile <- optional parseProgramInputFile
pure CasmRunOptions {..}
13 changes: 13 additions & 0 deletions app/CommonOptions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ parseInputFiles exts' = do
parseInputFile :: FileExt -> Parser (AppPath File)
parseInputFile = parseInputFiles . NonEmpty.singleton

parseProgramInputFile :: Parser (AppPath File)
parseProgramInputFile = do
_pathPath <-
option
somePreFileOpt
( long "program_input"
<> metavar "JSON_FILE"
<> help "Path to program input json file"
<> completer (extCompleter FileExtJson)
<> action "file"
)
pure AppPath {_pathIsInput = True, ..}

parseGenericInputFile :: Parser (AppPath File)
parseGenericInputFile = do
_pathPath <-
Expand Down
1 change: 1 addition & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ dependencies:
- primitive == 0.8.*
- process == 1.6.*
- safe == 0.3.*
- scientific == 0.3.*
- singletons == 3.0.*
- singletons-base == 3.3.*
- singletons-th == 3.3.*
Expand Down
2 changes: 1 addition & 1 deletion scripts/run_cairo_vm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

BASE=`basename "$1" .json`

cairo-vm-cli "$1" --print_output --proof_mode --trace_file ${BASE}.trace --air_public_input=${BASE}_public_input.json --air_private_input=${BASE}_private_input.json --memory_file=${BASE}_memory.mem --layout=small
juvix-cairo-vm "$@" --print_output --proof_mode --trace_file ${BASE}.trace --air_public_input=${BASE}_public_input.json --air_private_input=${BASE}_private_input.json --memory_file=${BASE}_memory.mem --layout=small
23 changes: 22 additions & 1 deletion src/Juvix/Compiler/Backend/Cairo/Data/Result.hs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
module Juvix.Compiler.Backend.Cairo.Data.Result where

import Data.Aeson as Aeson hiding (Result)
import Data.Aeson.Types hiding (Result)
import Data.Vector qualified as V
import Juvix.Prelude hiding ((.=))

data Result = Result
{ _resultData :: [Text],
_resultStart :: Int,
_resultEnd :: Int,
_resultMain :: Int,
_resultHints :: [(Int, Text)],
_resultBuiltins :: [Text]
}

Expand All @@ -19,7 +22,7 @@ instance ToJSON Result where
[ "data" .= toJSON _resultData,
"attributes" .= Array mempty,
"builtins" .= toJSON _resultBuiltins,
"hints" .= object [],
"hints" .= object (map mkHint _resultHints),
"identifiers"
.= object
[ "__main__.__start__"
Expand All @@ -46,3 +49,21 @@ instance ToJSON Result where
[ "references" .= Array mempty
]
]
where
mkHint :: (Int, Text) -> Pair
mkHint (pc, hintCode) = (fromString (show pc), Array $ V.fromList [hint])
where
hint =
object
[ "accessible_scopes" .= Array mempty,
"code" .= hintCode,
"flow_tracking_data"
.= object
[ "ap_tracking"
.= object
[ "group" .= Number 0,
"offset" .= Number 0
],
"reference_ids" .= object []
]
]
18 changes: 18 additions & 0 deletions src/Juvix/Compiler/Backend/Cairo/Extra/Serialization.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,21 @@ serialize elems =
_resultStart = 0,
_resultEnd = length initializeOutput + length elems + length finalizeOutput,
_resultMain = 0,
_resultHints = hints,
_resultBuiltins = ["output"]
}
where
hints :: [(Int, Text)]
hints = catMaybes $ zipWith mkHint elems [0 ..]

pcShift :: Int
pcShift = length initializeOutput

mkHint :: Element -> Int -> Maybe (Int, Text)
mkHint el pc = case el of
ElementHint Hint {..} -> Just (pc + pcShift, _hintCode)
_ -> Nothing

toHexText :: Natural -> Text
toHexText n = "0x" <> fromString (showHex n "")

Expand Down Expand Up @@ -48,6 +60,12 @@ serialize' = map goElement
goElement = \case
ElementInstruction i -> goInstr i
ElementImmediate f -> fieldToNatural f
ElementHint h -> goHint h

goHint :: Hint -> Natural
goHint Hint {..}
| _hintIncAp = 0x481280007fff8000
| otherwise = 0x401280007fff8000

goInstr :: Instruction -> Natural
goInstr Instruction {..} =
Expand Down
7 changes: 7 additions & 0 deletions src/Juvix/Compiler/Backend/Cairo/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ import Juvix.Data.Field
data Element
= ElementInstruction Instruction
| ElementImmediate FField
| ElementHint Hint

data Hint = Hint
{ _hintCode :: Text,
_hintIncAp :: Bool
}

data Instruction = Instruction
{ _instrOffDst :: Offset,
Expand Down Expand Up @@ -80,3 +86,4 @@ defaultInstruction =
}

makeLenses ''Instruction
makeLenses ''Hint
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Backend/Cairo/Translation/FromCasm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ fromCasm instrs0 =
Casm.Call x -> goCall x
Casm.Return -> goReturn
Casm.Alloc x -> goAlloc x
Casm.Hint x -> goHint x
Casm.Trace {} -> []
Casm.Label {} -> []
Casm.Nop -> []
Expand Down Expand Up @@ -228,3 +229,8 @@ fromCasm instrs0 =
. updateOps False _instrAllocSize
. set instrApUpdate ApUpdateAdd
$ defaultInstruction

goHint :: Casm.Hint -> [Element]
goHint = \case
Casm.HintInput var -> [ElementHint (Hint ("Input(" <> var <> ")") True)]
Casm.HintAlloc size -> [ElementHint (Hint ("Alloc(" <> show size <> ")") True)]
38 changes: 38 additions & 0 deletions src/Juvix/Compiler/Casm/Data/InputInfo.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
module Juvix.Compiler.Casm.Data.InputInfo where

import Data.Aeson
import Data.Aeson.Key
import Data.Aeson.KeyMap qualified as KeyMap
import Data.Aeson.Types
import Data.HashMap.Strict qualified as HashMap
import Data.Scientific
import Juvix.Data.Field
import Juvix.Prelude

newtype InputInfo = InputInfo
{ _inputInfoMap :: HashMap Text FField
}
deriving stock (Generic, Show)

makeLenses ''InputInfo

instance FromJSON InputInfo where
parseJSON = \case
Object obj -> do
lst <-
forM (KeyMap.toList obj) $ \(k, v) -> do
v' <- parseFField v
return (toText k, v')
return
. InputInfo
. HashMap.fromList
$ lst
v -> typeMismatch "Object" v
where
parseFField :: Value -> Parser FField
parseFField = \case
Number x
| isInteger x ->
return $ fieldFromInteger cairoFieldSize (fromRight 0 $ floatingOrInteger @Double x)
v ->
typeMismatch "Integer" v
12 changes: 12 additions & 0 deletions src/Juvix/Compiler/Casm/Extra/InputInfo.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module Juvix.Compiler.Casm.Extra.InputInfo where

import Juvix.Compiler.Casm.Data.InputInfo
import Juvix.Prelude
import Juvix.Prelude.Aeson

readInputInfo :: Maybe (Path Abs File) -> IO InputInfo
readInputInfo inputFile = case inputFile of
Just file ->
fromJust <$> readJSONFile (toFilePath file)
Nothing ->
return $ InputInfo mempty
20 changes: 17 additions & 3 deletions src/Juvix/Compiler/Casm/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Data.HashMap.Strict qualified as HashMap
import Data.Vector qualified as Vec
import Data.Vector.Mutable qualified as MV
import GHC.IO qualified as GHC
import Juvix.Compiler.Casm.Data.InputInfo
import Juvix.Compiler.Casm.Data.LabelInfo
import Juvix.Compiler.Casm.Error
import Juvix.Compiler.Casm.Interpreter.Error
Expand All @@ -19,12 +20,12 @@ import Juvix.Data.Field

type Memory s = MV.MVector s (Maybe FField)

runCode :: LabelInfo -> [Instruction] -> FField
runCode :: InputInfo -> LabelInfo -> [Instruction] -> FField
runCode = hRunCode stderr

-- | Runs Cairo Assembly. Returns the value of `[ap - 1]` at program exit.
hRunCode :: Handle -> LabelInfo -> [Instruction] -> FField
hRunCode hout (LabelInfo labelInfo) instrs0 = runST goCode
hRunCode :: Handle -> InputInfo -> LabelInfo -> [Instruction] -> FField
hRunCode hout inputInfo (LabelInfo labelInfo) instrs0 = runST goCode
where
instrs :: Vec.Vector Instruction
instrs = Vec.fromList instrs0
Expand Down Expand Up @@ -61,6 +62,7 @@ hRunCode hout (LabelInfo labelInfo) instrs0 = runST goCode
Return -> goReturn pc ap fp mem
Alloc x -> goAlloc x pc ap fp mem
Trace x -> goTrace x pc ap fp mem
Hint x -> goHint x pc ap fp mem
Label {} -> go (pc + 1) ap fp mem
Nop -> go (pc + 1) ap fp mem

Expand Down Expand Up @@ -241,6 +243,18 @@ hRunCode hout (LabelInfo labelInfo) instrs0 = runST goCode
GHC.unsafePerformIO (hPrint hout v >> return (pure ()))
go (pc + 1) ap fp mem

goHint :: Hint -> Address -> Address -> Address -> Memory s -> ST s FField
goHint hint pc ap fp mem = case hint of
HintInput var -> do
let val =
fromMaybe (throwRunError "invalid input") $
HashMap.lookup var (inputInfo ^. inputInfoMap)
mem' <- writeMem mem ap val
go (pc + 1) (ap + 1) fp mem'
HintAlloc size -> do
mem' <- writeMem mem ap (fieldFromInteger fsize (fromIntegral ap + 1))
go (pc + 1) (ap + size + 1) fp mem'

goFinish :: Address -> Memory s -> ST s FField
goFinish ap mem = do
checkGaps mem
Expand Down
5 changes: 5 additions & 0 deletions src/Juvix/Compiler/Casm/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ data ExtraOpcode
| -- | Sets the result to zero if arg1 < arg2, or to non-zero otherwise
IntLt

data Hint
= HintInput Text
| HintAlloc Int

data Instruction
= Assign InstrAssign
| -- | Extra binary operation not directly available in Cairo Assembly bytecode,
Expand All @@ -85,6 +89,7 @@ data Instruction
| Return
| Alloc InstrAlloc
| Trace InstrTrace
| Hint Hint
| Label LabelRef
| Nop

Expand Down
Loading
Loading