Skip to content

Commit

Permalink
refactor: make some partial functions total
Browse files Browse the repository at this point in the history
This patch cleans up a few cases of partial definitions, making them
total. There is also a small amount of refactoring and renaming to be
consistent with TensorLib.
  • Loading branch information
govereau committed Jan 27, 2025
1 parent 2f03c87 commit c65ccb5
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 103 deletions.
4 changes: 2 additions & 2 deletions NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import NKL.Trace
namespace NKL
open NKL.KLR

local instance : MonadLift (Except String) IO where
local instance : MonadLift Err IO where
monadLift
| .ok x => return x
| .error s => throw $ .userError s
Expand All @@ -21,4 +21,4 @@ def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
let stmts <- NKL.Trace.runNKIKernel kernel
for s in stmts do
IO.println (" " ++ Lean.format s) --s!"{s}\n{repr s}"
IO.println (" " ++ Lean.format s)
32 changes: 0 additions & 32 deletions NKL/KLR/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,38 +45,6 @@ inductive Const where
| string (value : String)
deriving Repr, BEq

namespace Const

-- Python-like rules for conversion to boolean
def isTrue : Const -> Bool
| .none => false
| .bool b => b
| .int i => i != 0
| .float f => f != 0.0
| .string s => s != ""

-- Python-like rules for conversion to integer
def toInt : Const -> Except String Int
| .none => throw "none cannot be converted to an integer"
| .bool true => return 1
| .bool false => return 0
| .int i => return i
| .float f =>
-- Python is a bit strange here, it truncates both
-- positive and negative numbers toward zero
if f < 0.0 then
return (Int.ofNat (Float.floor (-f)).toUInt64.toNat).neg
else
return Int.ofNat (Float.floor f).toUInt64.toNat
| .string s =>
-- Fortunately, Lean's String.toInt appears to be compatible
-- with Python's int(string) conversion.
match s.toInt? with
| .none => throw s!"string {s} cannot be converted to an integer"
| .some i => return i

end Const

-- This corresponds to the "Quasi-Affine Expressions" in Neuron.
-- Note, `floor` is the usual integer division.
inductive IndexExpr where
Expand Down
11 changes: 6 additions & 5 deletions NKL/KLR/Encode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.Util
import NKL.KLR.Basic

/-!
Expand All @@ -12,15 +13,15 @@ import NKL.KLR.Basic

namespace NKL.KLR

-- All of the encode function are pure; decoding uses an instance of EStateM.
-- All of the encode function are pure; decoding uses an instance of StM.

abbrev DecodeM := EStateM String ByteArray.Iterator
abbrev DecodeM := StM ByteArray.Iterator

def decode' (f : DecodeM a) (ba : ByteArray) : Option a :=
EStateM.run' f ba.iter
f.run' ba.iter

def decode (f : DecodeM a) (ba : ByteArray) : Except String a :=
match EStateM.run f ba.iter with
def decode (f : DecodeM a) (ba : ByteArray) : Err a :=
match f.run ba.iter with
| .ok x _ => .ok x
| .error s _ => .error s

Expand Down
10 changes: 3 additions & 7 deletions NKL/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import Lean
import NKL.Util

/-!
# Abstract syntax of Python functions
Expand Down Expand Up @@ -172,12 +173,7 @@ open Lean
-- span (Pos) is saved while traversing the tree to identify the location
-- of any errors in the original program

abbrev Parser := EStateM String Pos

local instance : MonadLift (Except String) Parser where
monadLift
| .ok x => return x
| .error s => throw s
abbrev Parser := StM Pos

private def str : Json -> Parser String :=
monadLift ∘ Json.getStr?
Expand Down Expand Up @@ -364,7 +360,7 @@ def kernel (j : Json) : Parser Kernel := do
let globals <- field (dict global) j "globals"
return Kernel.mk name funcs args kwargs globals

def parse (s : String) : Except String Kernel := do
def parse (s : String) : Err Kernel := do
let jsn <- Json.parse s
match kernel jsn {} with
| .ok x _ => .ok x
Expand Down
2 changes: 1 addition & 1 deletion NKL/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import NKL.Trace.NKI

namespace NKL.Trace

def runNKIKernel (k : NKL.Python.Kernel) : Except String (List NKL.KLR.Stmt) :=
def runNKIKernel (k : NKL.Python.Kernel) : Err (List NKL.KLR.Stmt) :=
tracer ⟨ .ofList NKIEnv, #[] ⟩ do
traceKernel k
let g <- get
Expand Down
36 changes: 34 additions & 2 deletions NKL/Trace/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,44 @@ import NKL.Trace.Tensor
Basic tracing definitions only deal with Terms (not Python sources)
-/

namespace NKL.KLR.Const

-- Python-like rules for conversion to boolean
def isTrue : Const -> Bool
| .none => false
| .bool b => b
| .int i => i != 0
| .float f => f != 0.0
| .string s => s != ""

-- Python-like rules for conversion to integer
def toInt : Const -> Err Int
| .none => throw "none cannot be converted to an integer"
| .bool true => return 1
| .bool false => return 0
| .int i => return i
| .float f =>
-- Python is a bit strange here, it truncates both
-- positive and negative numbers toward zero
if f < 0.0 then
return (Int.ofNat (Float.floor (-f)).toUInt64.toNat).neg
else
return Int.ofNat (Float.floor f).toUInt64.toNat
| .string s =>
-- Fortunately, Lean's String.toInt appears to be compatible
-- with Python's int(string) conversion.
match s.toInt? with
| .none => throw s!"string {s} cannot be converted to an integer"
| .some i => return i

end NKL.KLR.Const

namespace NKL.Trace
open NKL.KLR

-- Operators within index expressions

def indexBinOp : String -> KLR.IndexExpr -> KLR.IndexExpr -> ErrorM KLR.IndexExpr
def indexBinOp : String -> KLR.IndexExpr -> KLR.IndexExpr -> Err KLR.IndexExpr
| "Add" , l, r => return .add l r
| "Sub" , l, r => return .add l r.neg
| "Mult", .int i, e
Expand All @@ -27,7 +59,7 @@ def indexBinOp : String -> KLR.IndexExpr -> KLR.IndexExpr -> ErrorM KLR.IndexExp
| "Mod" , e, .int i => return .mod e i
| _, _, _ => throw "invalid index expression"

def indexUnOp : String -> KLR.IndexExpr -> ErrorM KLR.IndexExpr
def indexUnOp : String -> KLR.IndexExpr -> Err KLR.IndexExpr
| "USub", e => return .neg e
| _, _ => throw "invalid index expresssion"

Expand Down
6 changes: 3 additions & 3 deletions NKL/Trace/Builtin.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import NKL.Trace.Types
namespace NKL.Trace
open NKL.KLR

abbrev BuiltinAttr := String -> ErrorM Term
abbrev BuiltinAttr := String -> Err Term
abbrev GlobalAttr := String -> TraceM Term

abbrev BuiltinFn := List Expr -> List (String × Expr) -> ErrorM Term
abbrev BuiltinFn := List Expr -> List (String × Expr) -> Err Term
abbrev GlobalFn := List Term -> List (String × Term) -> TraceM Term

def noattrs [Monad m] [MonadExcept String m] : Name -> String -> m a :=
Expand Down Expand Up @@ -61,7 +61,7 @@ def simple_object {a : Type}
, call := uncallable name
}
where
attr_fn (attr : String) : Except String Term :=
attr_fn (attr : String) : Err Term :=
match attrs.find? (fun x => x.fst == attr) with
| none => .error s!"{attr} is not an attribute of {name}"
| some (_,fn) => .ok (.object $ simple_function (name.str attr) (fn x))
Expand Down
62 changes: 27 additions & 35 deletions NKL/Trace/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import NKL.Trace.Basic
namespace NKL.Trace
open NKL.Python

def const : Const -> ErrorM Term
def const : Const -> Err Term
| .none => return .expr (.const $ .none) .none
| .bool b => return .expr (.const $ .bool b) .bool
| .int i => return .expr (.const $ .int i) .int
Expand Down Expand Up @@ -132,21 +132,21 @@ dead-code elimination for simple assignments.
-/

-- Convert an expression in assignment context (an L-Value).
partial def LValue : Expr -> Tracer Term
def LValue : Expr -> Tracer Term
| .exprPos e' p => withPos p (lval e')
where
lval : Expr' -> Tracer Term
| .name id .store => return .expr (.var id) (.any "?".toName)
| .tuple l .store => return .tuple (<- l.mapM LValue)
| .list l .store => return .list (<- l.mapM LValue)
| .tuple l .store => return .tuple (<- LValue ▷ l)
| .list l .store => return .list (<- LValue ▷ l)
| _ => throw "cannot assign to expression"

-- Convert an R-Value to a pure expression, emitting
-- additional assignments as needed.
partial def RValue : Term -> Tracer Term
def RValue : Term -> Tracer Term
| .object o => return .object o
| .tuple l => return .tuple (<- l.mapM RValue)
| .list l => return .list (<- l.mapM RValue)
| .tuple l => return .tuple (<- RValue ▷ l)
| .list l => return .list (<- RValue ▷ l)
| .expr e@(.call _ _ _) ty => do
let v := (<- genName).toString
add_stmt (.assign v e)
Expand Down Expand Up @@ -197,45 +197,37 @@ mutual
partial def expr : Expr -> Tracer Item
| .exprPos e' p => withPos p (expr' e')

partial def term (e : Expr) : Tracer Term := do
match (<- expr e) with
| .module n => return .expr (.var n.toString) (.any "?".toName)
| .global g => return .expr (.var g.name.toString) (.any "?".toName)
| .source _ => throw "invalid use of source function"
| .term t => return t
partial def term (e : Expr) : Tracer Term :=
return <- (<- expr e).toTerm

partial def term' (e : Expr') : Tracer Term := do
term (.exprPos e (<- getPos))
partial def term' (e : Expr') : Tracer Term :=
return <- (<- expr' e).toTerm

partial def klr (e : Expr) : Tracer KLR.Expr := do
match (<- term e) with
| .object obj => return .var obj.name.toString
| .tuple _ => throw "tuple cannot be converted to a KLR term"
| .list _ => throw "list cannot be converted to a KLR term"
| .expr e _ => return e
partial def klr (e : Expr) : Tracer KLR.Expr :=
return <- (<- term e).toKLR

partial def integer (e : Expr) : Tracer Int := do
match (<- term e) with
| .expr (.const c) _ => return (<- c.toInt)
| _ => throw "invalid tensor dimension"
match <- klr e with
| .const c => return (<- c.toInt)
| _ => throw "expecting integer"

partial def expr' : Expr' -> Tracer Item
| .const c => return .term (<- const c)
| .tensor s dty => do
let shape <- s.mapM integer
let shape <- integer ▷ s
let name <- genName "t".toName
return .term (.expr (.tensor ⟨ name.toString, dty, shape ⟩) (.tensor dty shape))
| .name id _ => lookup_item id.toName
| .attr (.exprPos e p) id _ => do withPos p ((<- expr' e).attr id)
| .tuple l _ => return .term (.tuple (<- l.mapM term))
| .list l _ => return .term (.list (<- l.mapM term))
| .tuple l _ => return .term (.tuple (<- term ▷ l))
| .list l _ => return .term (.list (<- term ▷ l))
| .subscript t [ .exprPos (.tuple ix _) _ ] _
| .subscript t ix _ => return .term (<- access (<- term t) (<- ix.mapM index))
| .subscript t ix _ => return .term (<- access (<- term t) (<- index ▷ ix))
| .slice _ _ _ => throw "syntax error"
| .boolOp op xs => return .term (<- boolOp op (<- xs.mapM term))
| .boolOp op xs => return .term (<- boolOp op (<- term ▷ xs))
| .binOp op l r => return .term (<- binOp op (<- term l) (<- term r))
| .unaryOp op e => return .term (<- unOp op (<- term e))
| .compare l ops cs => return .term (<- compare (<- term l) ops (<- cs.mapM term))
| .compare l ops cs => return .term (<- compare (<- term l) ops (<- term ▷ cs))
| .ifExp tst tru fls => do
let tst <- (<- term tst).isTrue
let tru <- expr tru -- eagerly evaluate both branches
Expand All @@ -244,10 +236,10 @@ partial def expr' : Expr' -> Tracer Item
| .call f args kws => do
match <- expr f with
| .module n => throw s!"module {n} not callable"
| .global g => return .term (<- g.call (<- args.mapM term) (<- kws.mapM (keyword term)))
| .term t => return .term (<- t.call (<- args.mapM klr) (<- kws.mapM (keyword klr)))
| .global g => return .term (<- g.call (<- term ▷ args) (<- keyword term ▷ kws))
| .term t => return .term (<- t.call (<- klr ▷ args) (<- keyword klr ▷ kws))
| .source f => do
function_call f (<- args.mapM term) (<- kws.mapM (keyword term))
function_call f (<- term ▷ args) (<- keyword term ▷ kws)
return .term (.expr (.const .none) .none)

partial def keyword (f : Expr -> Tracer a) : Keyword -> Tracer (String × a)
Expand All @@ -263,7 +255,7 @@ partial def stmt' : Stmt' -> Tracer Unit
| .assert e => do
let t <- term e
if (<- t.isFalse) then throw "assertion failed"
| .assign xs e => do assign (<- xs.mapM LValue) (<- term e)
| .assign xs e => do assign (<- LValue ▷ xs) (<- term e)
| .augAssign x op e => do
stmt' (.assign [x] (.exprPos (.binOp op x e) (<- getPos)))
| .annAssign _ _ .none => return ()
Expand Down Expand Up @@ -349,7 +341,7 @@ def traceKernel (k : Kernel) : Tracer Unit := do
let kwargs <- k.kwargs.mapM fun (x,e) => return (x, <- term' e)
function_call f args kwargs

def runKernel (k : Kernel) : Except String (List KLR.Stmt) :=
def runKernel (k : Kernel) : Err (List KLR.Stmt) :=
tracer ⟨ ∅, #[] ⟩ do
traceKernel k
let g <- get
Expand Down
Loading

0 comments on commit c65ccb5

Please sign in to comment.