Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: make some partial functions total #23

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import NKL.Trace
namespace NKL
open NKL.KLR

local instance : MonadLift (Except String) IO where
local instance : MonadLift Err IO where
monadLift
| .ok x => return x
| .error s => throw $ .userError s
Expand All @@ -21,4 +21,4 @@ def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
let stmts <- NKL.Trace.runNKIKernel kernel
for s in stmts do
IO.println (" " ++ Lean.format s) --s!"{s}\n{repr s}"
IO.println (" " ++ Lean.format s)
32 changes: 0 additions & 32 deletions NKL/KLR/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,38 +45,6 @@ inductive Const where
| string (value : String)
deriving Repr, BEq

namespace Const

-- Python-like rules for conversion to boolean
def isTrue : Const -> Bool
| .none => false
| .bool b => b
| .int i => i != 0
| .float f => f != 0.0
| .string s => s != ""

-- Python-like rules for conversion to integer
def toInt : Const -> Except String Int
| .none => throw "none cannot be converted to an integer"
| .bool true => return 1
| .bool false => return 0
| .int i => return i
| .float f =>
-- Python is a bit strange here, it truncates both
-- positive and negative numbers toward zero
if f < 0.0 then
return (Int.ofNat (Float.floor (-f)).toUInt64.toNat).neg
else
return Int.ofNat (Float.floor f).toUInt64.toNat
| .string s =>
-- Fortunately, Lean's String.toInt appears to be compatible
-- with Python's int(string) conversion.
match s.toInt? with
| .none => throw s!"string {s} cannot be converted to an integer"
| .some i => return i

end Const

-- This corresponds to the "Quasi-Affine Expressions" in Neuron.
-- Note, `floor` is the usual integer division.
inductive IndexExpr where
Expand Down
11 changes: 6 additions & 5 deletions NKL/KLR/Encode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.Util
import NKL.KLR.Basic

/-!
Expand All @@ -12,15 +13,15 @@ import NKL.KLR.Basic

namespace NKL.KLR

-- All of the encode function are pure; decoding uses an instance of EStateM.
-- All of the encode function are pure; decoding uses an instance of StM.

abbrev DecodeM := EStateM String ByteArray.Iterator
abbrev DecodeM := StM ByteArray.Iterator

def decode' (f : DecodeM a) (ba : ByteArray) : Option a :=
EStateM.run' f ba.iter
f.run' ba.iter

def decode (f : DecodeM a) (ba : ByteArray) : Except String a :=
match EStateM.run f ba.iter with
def decode (f : DecodeM a) (ba : ByteArray) : Err a :=
match f.run ba.iter with
| .ok x _ => .ok x
| .error s _ => .error s

Expand Down
10 changes: 3 additions & 7 deletions NKL/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import Lean
import NKL.Util

/-!
# Abstract syntax of Python functions
Expand Down Expand Up @@ -172,12 +173,7 @@ open Lean
-- span (Pos) is saved while traversing the tree to identify the location
-- of any errors in the original program

abbrev Parser := EStateM String Pos

local instance : MonadLift (Except String) Parser where
monadLift
| .ok x => return x
| .error s => throw s
abbrev Parser := StM Pos

private def str : Json -> Parser String :=
monadLift ∘ Json.getStr?
Expand Down Expand Up @@ -364,7 +360,7 @@ def kernel (j : Json) : Parser Kernel := do
let globals <- field (dict global) j "globals"
return Kernel.mk name funcs args kwargs globals

def parse (s : String) : Except String Kernel := do
def parse (s : String) : Err Kernel := do
let jsn <- Json.parse s
match kernel jsn {} with
| .ok x _ => .ok x
Expand Down
2 changes: 1 addition & 1 deletion NKL/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import NKL.Trace.NKI

namespace NKL.Trace

def runNKIKernel (k : NKL.Python.Kernel) : Except String (List NKL.KLR.Stmt) :=
def runNKIKernel (k : NKL.Python.Kernel) : Err (List NKL.KLR.Stmt) :=
tracer ⟨ .ofList NKIEnv, #[] ⟩ do
traceKernel k
let g <- get
Expand Down
36 changes: 34 additions & 2 deletions NKL/Trace/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,44 @@ import NKL.Trace.Tensor
Basic tracing definitions only deal with Terms (not Python sources)
-/

namespace NKL.KLR.Const

-- Python-like rules for conversion to boolean
def isTrue : Const -> Bool
| .none => false
| .bool b => b
| .int i => i != 0
| .float f => f != 0.0
| .string s => s != ""

-- Python-like rules for conversion to integer
def toInt : Const -> Err Int
| .none => throw "none cannot be converted to an integer"
| .bool true => return 1
| .bool false => return 0
| .int i => return i
| .float f =>
-- Python is a bit strange here, it truncates both
-- positive and negative numbers toward zero
if f < 0.0 then
return (Int.ofNat (Float.floor (-f)).toUInt64.toNat).neg
else
return Int.ofNat (Float.floor f).toUInt64.toNat
| .string s =>
-- Fortunately, Lean's String.toInt appears to be compatible
-- with Python's int(string) conversion.
match s.toInt? with
| .none => throw s!"string {s} cannot be converted to an integer"
| .some i => return i

end NKL.KLR.Const

namespace NKL.Trace
open NKL.KLR

-- Operators within index expressions

def indexBinOp : String -> KLR.IndexExpr -> KLR.IndexExpr -> ErrorM KLR.IndexExpr
def indexBinOp : String -> KLR.IndexExpr -> KLR.IndexExpr -> Err KLR.IndexExpr
| "Add" , l, r => return .add l r
| "Sub" , l, r => return .add l r.neg
| "Mult", .int i, e
Expand All @@ -27,7 +59,7 @@ def indexBinOp : String -> KLR.IndexExpr -> KLR.IndexExpr -> ErrorM KLR.IndexExp
| "Mod" , e, .int i => return .mod e i
| _, _, _ => throw "invalid index expression"

def indexUnOp : String -> KLR.IndexExpr -> ErrorM KLR.IndexExpr
def indexUnOp : String -> KLR.IndexExpr -> Err KLR.IndexExpr
| "USub", e => return .neg e
| _, _ => throw "invalid index expresssion"

Expand Down
6 changes: 3 additions & 3 deletions NKL/Trace/Builtin.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import NKL.Trace.Types
namespace NKL.Trace
open NKL.KLR

abbrev BuiltinAttr := String -> ErrorM Term
abbrev BuiltinAttr := String -> Err Term
abbrev GlobalAttr := String -> TraceM Term

abbrev BuiltinFn := List Expr -> List (String × Expr) -> ErrorM Term
abbrev BuiltinFn := List Expr -> List (String × Expr) -> Err Term
abbrev GlobalFn := List Term -> List (String × Term) -> TraceM Term

def noattrs [Monad m] [MonadExcept String m] : Name -> String -> m a :=
Expand Down Expand Up @@ -61,7 +61,7 @@ def simple_object {a : Type}
, call := uncallable name
}
where
attr_fn (attr : String) : Except String Term :=
attr_fn (attr : String) : Err Term :=
match attrs.find? (fun x => x.fst == attr) with
| none => .error s!"{attr} is not an attribute of {name}"
| some (_,fn) => .ok (.object $ simple_function (name.str attr) (fn x))
Expand Down
62 changes: 27 additions & 35 deletions NKL/Trace/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import NKL.Trace.Basic
namespace NKL.Trace
open NKL.Python

def const : Const -> ErrorM Term
def const : Const -> Err Term
| .none => return .expr (.const $ .none) .none
| .bool b => return .expr (.const $ .bool b) .bool
| .int i => return .expr (.const $ .int i) .int
Expand Down Expand Up @@ -132,21 +132,21 @@ dead-code elimination for simple assignments.
-/

-- Convert an expression in assignment context (an L-Value).
partial def LValue : Expr -> Tracer Term
def LValue : Expr -> Tracer Term
| .exprPos e' p => withPos p (lval e')
where
lval : Expr' -> Tracer Term
| .name id .store => return .expr (.var id) (.any "?".toName)
| .tuple l .store => return .tuple (<- l.mapM LValue)
| .list l .store => return .list (<- l.mapM LValue)
| .tuple l .store => return .tuple (<- LValue ▷ l)
govereau marked this conversation as resolved.
Show resolved Hide resolved
| .list l .store => return .list (<- LValue ▷ l)
| _ => throw "cannot assign to expression"

-- Convert an R-Value to a pure expression, emitting
-- additional assignments as needed.
partial def RValue : Term -> Tracer Term
def RValue : Term -> Tracer Term
| .object o => return .object o
| .tuple l => return .tuple (<- l.mapM RValue)
| .list l => return .list (<- l.mapM RValue)
| .tuple l => return .tuple (<- RValue ▷ l)
| .list l => return .list (<- RValue ▷ l)
| .expr e@(.call _ _ _) ty => do
let v := (<- genName).toString
add_stmt (.assign v e)
Expand Down Expand Up @@ -197,45 +197,37 @@ mutual
partial def expr : Expr -> Tracer Item
| .exprPos e' p => withPos p (expr' e')

partial def term (e : Expr) : Tracer Term := do
match (<- expr e) with
| .module n => return .expr (.var n.toString) (.any "?".toName)
| .global g => return .expr (.var g.name.toString) (.any "?".toName)
| .source _ => throw "invalid use of source function"
| .term t => return t
partial def term (e : Expr) : Tracer Term :=
return <- (<- expr e).toTerm

partial def term' (e : Expr') : Tracer Term := do
term (.exprPos e (<- getPos))
partial def term' (e : Expr') : Tracer Term :=
return <- (<- expr' e).toTerm

partial def klr (e : Expr) : Tracer KLR.Expr := do
match (<- term e) with
| .object obj => return .var obj.name.toString
| .tuple _ => throw "tuple cannot be converted to a KLR term"
| .list _ => throw "list cannot be converted to a KLR term"
| .expr e _ => return e
partial def klr (e : Expr) : Tracer KLR.Expr :=
return <- (<- term e).toKLR

partial def integer (e : Expr) : Tracer Int := do
match (<- term e) with
| .expr (.const c) _ => return (<- c.toInt)
| _ => throw "invalid tensor dimension"
match <- klr e with
| .const c => return (<- c.toInt)
govereau marked this conversation as resolved.
Show resolved Hide resolved
| _ => throw "expecting integer"

partial def expr' : Expr' -> Tracer Item
| .const c => return .term (<- const c)
| .tensor s dty => do
let shape <- s.mapM integer
let shape <- integer ▷ s
let name <- genName "t".toName
return .term (.expr (.tensor ⟨ name.toString, dty, shape ⟩) (.tensor dty shape))
| .name id _ => lookup_item id.toName
| .attr (.exprPos e p) id _ => do withPos p ((<- expr' e).attr id)
| .tuple l _ => return .term (.tuple (<- l.mapM term))
| .list l _ => return .term (.list (<- l.mapM term))
| .tuple l _ => return .term (.tuple (<- term ▷ l))
| .list l _ => return .term (.list (<- term ▷ l))
| .subscript t [ .exprPos (.tuple ix _) _ ] _
| .subscript t ix _ => return .term (<- access (<- term t) (<- ix.mapM index))
| .subscript t ix _ => return .term (<- access (<- term t) (<- index ▷ ix))
| .slice _ _ _ => throw "syntax error"
| .boolOp op xs => return .term (<- boolOp op (<- xs.mapM term))
| .boolOp op xs => return .term (<- boolOp op (<- term ▷ xs))
| .binOp op l r => return .term (<- binOp op (<- term l) (<- term r))
| .unaryOp op e => return .term (<- unOp op (<- term e))
| .compare l ops cs => return .term (<- compare (<- term l) ops (<- cs.mapM term))
| .compare l ops cs => return .term (<- compare (<- term l) ops (<- term ▷ cs))
| .ifExp tst tru fls => do
let tst <- (<- term tst).isTrue
let tru <- expr tru -- eagerly evaluate both branches
Expand All @@ -244,10 +236,10 @@ partial def expr' : Expr' -> Tracer Item
| .call f args kws => do
match <- expr f with
| .module n => throw s!"module {n} not callable"
| .global g => return .term (<- g.call (<- args.mapM term) (<- kws.mapM (keyword term)))
| .term t => return .term (<- t.call (<- args.mapM klr) (<- kws.mapM (keyword klr)))
| .global g => return .term (<- g.call (<- term ▷ args) (<- keyword term ▷ kws))
| .term t => return .term (<- t.call (<- klr ▷ args) (<- keyword klr ▷ kws))
| .source f => do
function_call f (<- args.mapM term) (<- kws.mapM (keyword term))
function_call f (<- term ▷ args) (<- keyword term ▷ kws)
return .term (.expr (.const .none) .none)

partial def keyword (f : Expr -> Tracer a) : Keyword -> Tracer (String × a)
Expand All @@ -263,7 +255,7 @@ partial def stmt' : Stmt' -> Tracer Unit
| .assert e => do
let t <- term e
if (<- t.isFalse) then throw "assertion failed"
| .assign xs e => do assign (<- xs.mapM LValue) (<- term e)
| .assign xs e => do assign (<- LValue ▷ xs) (<- term e)
| .augAssign x op e => do
stmt' (.assign [x] (.exprPos (.binOp op x e) (<- getPos)))
| .annAssign _ _ .none => return ()
Expand Down Expand Up @@ -349,7 +341,7 @@ def traceKernel (k : Kernel) : Tracer Unit := do
let kwargs <- k.kwargs.mapM fun (x,e) => return (x, <- term' e)
function_call f args kwargs

def runKernel (k : Kernel) : Except String (List KLR.Stmt) :=
def runKernel (k : Kernel) : Err (List KLR.Stmt) :=
tracer ⟨ ∅, #[] ⟩ do
traceKernel k
let g <- get
Expand Down
Loading
Loading