From 512371f3c42f719182ea8db66c618c4677371d18 Mon Sep 17 00:00:00 2001 From: Paul Govereau Date: Sun, 26 Jan 2025 09:25:37 -0500 Subject: [PATCH] refactor: make some partial functions total This patch cleans up a few cases of partial definitions, making them total. There is also a small amount of refactoring and renaming to be consistent with TensorLib. --- NKL/FFI.lean | 4 +-- NKL/KLR/Basic.lean | 32 ---------------------- NKL/KLR/Encode.lean | 11 ++++---- NKL/Python.lean | 10 ++----- NKL/Trace.lean | 2 +- NKL/Trace/Basic.lean | 36 ++++++++++++++++++++++-- NKL/Trace/Builtin.lean | 6 ++-- NKL/Trace/Python.lean | 62 ++++++++++++++++++------------------------ NKL/Trace/Types.lean | 38 +++++++++++++++----------- NKL/Util.lean | 47 ++++++++++++++++++++++++++++++++ 10 files changed, 145 insertions(+), 103 deletions(-) create mode 100644 NKL/Util.lean diff --git a/NKL/FFI.lean b/NKL/FFI.lean index d8e7fe1..28f800f 100644 --- a/NKL/FFI.lean +++ b/NKL/FFI.lean @@ -11,7 +11,7 @@ import NKL.Trace namespace NKL open NKL.KLR -local instance : MonadLift (Except String) IO where +local instance : MonadLift Err IO where monadLift | .ok x => return x | .error s => throw $ .userError s @@ -21,4 +21,4 @@ def parse_json (s : String) : IO Unit := do let kernel <- Python.Parsing.parse s let stmts <- NKL.Trace.runNKIKernel kernel for s in stmts do - IO.println (" " ++ Lean.format s) --s!"{s}\n{repr s}" + IO.println (" " ++ Lean.format s) diff --git a/NKL/KLR/Basic.lean b/NKL/KLR/Basic.lean index fab5989..76b7bd6 100644 --- a/NKL/KLR/Basic.lean +++ b/NKL/KLR/Basic.lean @@ -45,38 +45,6 @@ inductive Const where | string (value : String) deriving Repr, BEq -namespace Const - --- Python-like rules for conversion to boolean -def isTrue : Const -> Bool - | .none => false - | .bool b => b - | .int i => i != 0 - | .float f => f != 0.0 - | .string s => s != "" - --- Python-like rules for conversion to integer -def toInt : Const -> Except String Int - | .none => throw "none cannot be converted to 
an integer" - | .bool true => return 1 - | .bool false => return 0 - | .int i => return i - | .float f => - -- Python is a bit strange here, it truncates both - -- positive and negative numbers toward zero - if f < 0.0 then - return (Int.ofNat (Float.floor (-f)).toUInt64.toNat).neg - else - return Int.ofNat (Float.floor f).toUInt64.toNat - | .string s => - -- Fortunately, Lean's String.toInt appears to be compatible - -- with Python's int(string) conversion. - match s.toInt? with - | .none => throw s!"string {s} cannot be converted to an integer" - | .some i => return i - -end Const - -- This corresponds to the "Quasi-Affine Expressions" in Neuron. -- Note, `floor` is the usual integer division. inductive IndexExpr where diff --git a/NKL/KLR/Encode.lean b/NKL/KLR/Encode.lean index c5336a1..320f72f 100644 --- a/NKL/KLR/Encode.lean +++ b/NKL/KLR/Encode.lean @@ -3,6 +3,7 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ +import NKL.Util import NKL.KLR.Basic /-! @@ -12,15 +13,15 @@ import NKL.KLR.Basic namespace NKL.KLR --- All of the encode function are pure; decoding uses an instance of EStateM. +-- All of the encode function are pure; decoding uses an instance of StM. -abbrev DecodeM := EStateM String ByteArray.Iterator +abbrev DecodeM := StM ByteArray.Iterator def decode' (f : DecodeM a) (ba : ByteArray) : Option a := - EStateM.run' f ba.iter + f.run' ba.iter -def decode (f : DecodeM a) (ba : ByteArray) : Except String a := - match EStateM.run f ba.iter with +def decode (f : DecodeM a) (ba : ByteArray) : Err a := + match f.run ba.iter with | .ok x _ => .ok x | .error s _ => .error s diff --git a/NKL/Python.lean b/NKL/Python.lean index b2aedd0..9c47552 100644 --- a/NKL/Python.lean +++ b/NKL/Python.lean @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ import Lean +import NKL.Util /-! 
# Abstract syntax of Python functions @@ -172,12 +173,7 @@ open Lean -- span (Pos) is saved while traversing the tree to identify the location -- of any errors in the original program -abbrev Parser := EStateM String Pos - -local instance : MonadLift (Except String) Parser where - monadLift - | .ok x => return x - | .error s => throw s +abbrev Parser := StM Pos private def str : Json -> Parser String := monadLift ∘ Json.getStr? @@ -364,7 +360,7 @@ def kernel (j : Json) : Parser Kernel := do let globals <- field (dict global) j "globals" return Kernel.mk name funcs args kwargs globals -def parse (s : String) : Except String Kernel := do +def parse (s : String) : Err Kernel := do let jsn <- Json.parse s match kernel jsn {} with | .ok x _ => .ok x diff --git a/NKL/Trace.lean b/NKL/Trace.lean index 481a552..1f19403 100644 --- a/NKL/Trace.lean +++ b/NKL/Trace.lean @@ -13,7 +13,7 @@ import NKL.Trace.NKI namespace NKL.Trace -def runNKIKernel (k : NKL.Python.Kernel) : Except String (List NKL.KLR.Stmt) := +def runNKIKernel (k : NKL.Python.Kernel) : Err (List NKL.KLR.Stmt) := tracer ⟨ .ofList NKIEnv, #[] ⟩ do traceKernel k let g <- get diff --git a/NKL/Trace/Basic.lean b/NKL/Trace/Basic.lean index 59d1f7d..83167a5 100644 --- a/NKL/Trace/Basic.lean +++ b/NKL/Trace/Basic.lean @@ -13,12 +13,44 @@ import NKL.Trace.Tensor Basic tracing definitions only deal with Terms (not Python sources) -/ +namespace NKL.KLR.Const + +-- Python-like rules for conversion to boolean +def isTrue : Const -> Bool + | .none => false + | .bool b => b + | .int i => i != 0 + | .float f => f != 0.0 + | .string s => s != "" + +-- Python-like rules for conversion to integer +def toInt : Const -> Err Int + | .none => throw "none cannot be converted to an integer" + | .bool true => return 1 + | .bool false => return 0 + | .int i => return i + | .float f => + -- Python is a bit strange here, it truncates both + -- positive and negative numbers toward zero + if f < 0.0 then + return (Int.ofNat (Float.floor 
(-f)).toUInt64.toNat).neg + else + return Int.ofNat (Float.floor f).toUInt64.toNat + | .string s => + -- Fortunately, Lean's String.toInt appears to be compatible + -- with Python's int(string) conversion. + match s.toInt? with + | .none => throw s!"string {s} cannot be converted to an integer" + | .some i => return i + +end NKL.KLR.Const + namespace NKL.Trace open NKL.KLR -- Operators within index expressions -def indexBinOp : String -> KLR.IndexExpr -> KLR.IndexExpr -> ErrorM KLR.IndexExpr +def indexBinOp : String -> KLR.IndexExpr -> KLR.IndexExpr -> Err KLR.IndexExpr | "Add" , l, r => return .add l r | "Sub" , l, r => return .add l r.neg | "Mult", .int i, e @@ -27,7 +59,7 @@ def indexBinOp : String -> KLR.IndexExpr -> KLR.IndexExpr -> ErrorM KLR.IndexExp | "Mod" , e, .int i => return .mod e i | _, _, _ => throw "invalid index expression" -def indexUnOp : String -> KLR.IndexExpr -> ErrorM KLR.IndexExpr +def indexUnOp : String -> KLR.IndexExpr -> Err KLR.IndexExpr | "USub", e => return .neg e | _, _ => throw "invalid index expresssion" diff --git a/NKL/Trace/Builtin.lean b/NKL/Trace/Builtin.lean index d581ece..3233bc0 100644 --- a/NKL/Trace/Builtin.lean +++ b/NKL/Trace/Builtin.lean @@ -14,10 +14,10 @@ import NKL.Trace.Types namespace NKL.Trace open NKL.KLR -abbrev BuiltinAttr := String -> ErrorM Term +abbrev BuiltinAttr := String -> Err Term abbrev GlobalAttr := String -> TraceM Term -abbrev BuiltinFn := List Expr -> List (String × Expr) -> ErrorM Term +abbrev BuiltinFn := List Expr -> List (String × Expr) -> Err Term abbrev GlobalFn := List Term -> List (String × Term) -> TraceM Term def noattrs [Monad m] [MonadExcept String m] : Name -> String -> m a := @@ -61,7 +61,7 @@ def simple_object {a : Type} , call := uncallable name } where - attr_fn (attr : String) : Except String Term := + attr_fn (attr : String) : Err Term := match attrs.find? 
(fun x => x.fst == attr) with | none => .error s!"{attr} is not an attribute of {name}" | some (_,fn) => .ok (.object $ simple_function (name.str attr) (fn x)) diff --git a/NKL/Trace/Python.lean b/NKL/Trace/Python.lean index 0394453..85106bf 100644 --- a/NKL/Trace/Python.lean +++ b/NKL/Trace/Python.lean @@ -12,7 +12,7 @@ import NKL.Trace.Basic namespace NKL.Trace open NKL.Python -def const : Const -> ErrorM Term +def const : Const -> Err Term | .none => return .expr (.const $ .none) .none | .bool b => return .expr (.const $ .bool b) .bool | .int i => return .expr (.const $ .int i) .int @@ -132,21 +132,21 @@ dead-code elimination for simple assignments. -/ -- Convert an expression in assignment context (an L-Value). -partial def LValue : Expr -> Tracer Term +def LValue : Expr -> Tracer Term | .exprPos e' p => withPos p (lval e') where lval : Expr' -> Tracer Term | .name id .store => return .expr (.var id) (.any "?".toName) - | .tuple l .store => return .tuple (<- l.mapM LValue) - | .list l .store => return .list (<- l.mapM LValue) + | .tuple l .store => return .tuple (<- LValue ▷ l) + | .list l .store => return .list (<- LValue ▷ l) | _ => throw "cannot assign to expression" -- Convert an R-Value to a pure expression, emitting -- additional assignments as needed. 
-partial def RValue : Term -> Tracer Term +def RValue : Term -> Tracer Term | .object o => return .object o - | .tuple l => return .tuple (<- l.mapM RValue) - | .list l => return .list (<- l.mapM RValue) + | .tuple l => return .tuple (<- RValue ▷ l) + | .list l => return .list (<- RValue ▷ l) | .expr e@(.call _ _ _) ty => do let v := (<- genName).toString add_stmt (.assign v e) @@ -197,45 +197,37 @@ mutual partial def expr : Expr -> Tracer Item | .exprPos e' p => withPos p (expr' e') -partial def term (e : Expr) : Tracer Term := do - match (<- expr e) with - | .module n => return .expr (.var n.toString) (.any "?".toName) - | .global g => return .expr (.var g.name.toString) (.any "?".toName) - | .source _ => throw "invalid use of source function" - | .term t => return t +partial def term (e : Expr) : Tracer Term := + return <- (<- expr e).toTerm -partial def term' (e : Expr') : Tracer Term := do - term (.exprPos e (<- getPos)) +partial def term' (e : Expr') : Tracer Term := + return <- (<- expr' e).toTerm -partial def klr (e : Expr) : Tracer KLR.Expr := do - match (<- term e) with - | .object obj => return .var obj.name.toString - | .tuple _ => throw "tuple cannot be converted to a KLR term" - | .list _ => throw "list cannot be converted to a KLR term" - | .expr e _ => return e +partial def klr (e : Expr) : Tracer KLR.Expr := + return <- (<- term e).toKLR partial def integer (e : Expr) : Tracer Int := do - match (<- term e) with - | .expr (.const c) _ => return (<- c.toInt) - | _ => throw "invalid tensor dimension" + match <- klr e with + | .const c => return (<- c.toInt) + | _ => throw "expecting integer" partial def expr' : Expr' -> Tracer Item | .const c => return .term (<- const c) | .tensor s dty => do - let shape <- s.mapM integer + let shape <- integer ▷ s let name <- genName "t".toName return .term (.expr (.tensor ⟨ name.toString, dty, shape ⟩) (.tensor dty shape)) | .name id _ => lookup_item id.toName | .attr (.exprPos e p) id _ => do withPos p ((<- expr' 
e).attr id) - | .tuple l _ => return .term (.tuple (<- l.mapM term)) - | .list l _ => return .term (.list (<- l.mapM term)) + | .tuple l _ => return .term (.tuple (<- term ▷ l)) + | .list l _ => return .term (.list (<- term ▷ l)) | .subscript t [ .exprPos (.tuple ix _) _ ] _ - | .subscript t ix _ => return .term (<- access (<- term t) (<- ix.mapM index)) + | .subscript t ix _ => return .term (<- access (<- term t) (<- index ▷ ix)) | .slice _ _ _ => throw "syntax error" - | .boolOp op xs => return .term (<- boolOp op (<- xs.mapM term)) + | .boolOp op xs => return .term (<- boolOp op (<- term ▷ xs)) | .binOp op l r => return .term (<- binOp op (<- term l) (<- term r)) | .unaryOp op e => return .term (<- unOp op (<- term e)) - | .compare l ops cs => return .term (<- compare (<- term l) ops (<- cs.mapM term)) + | .compare l ops cs => return .term (<- compare (<- term l) ops (<- term ▷ cs)) | .ifExp tst tru fls => do let tst <- (<- term tst).isTrue let tru <- expr tru -- eagerly evaluate both branches @@ -244,10 +236,10 @@ partial def expr' : Expr' -> Tracer Item | .call f args kws => do match <- expr f with | .module n => throw s!"module {n} not callable" - | .global g => return .term (<- g.call (<- args.mapM term) (<- kws.mapM (keyword term))) - | .term t => return .term (<- t.call (<- args.mapM klr) (<- kws.mapM (keyword klr))) + | .global g => return .term (<- g.call (<- term ▷ args) (<- keyword term ▷ kws)) + | .term t => return .term (<- t.call (<- klr ▷ args) (<- keyword klr ▷ kws)) | .source f => do - function_call f (<- args.mapM term) (<- kws.mapM (keyword term)) + function_call f (<- term ▷ args) (<- keyword term ▷ kws) return .term (.expr (.const .none) .none) partial def keyword (f : Expr -> Tracer a) : Keyword -> Tracer (String × a) @@ -263,7 +255,7 @@ partial def stmt' : Stmt' -> Tracer Unit | .assert e => do let t <- term e if (<- t.isFalse) then throw "assertion failed" - | .assign xs e => do assign (<- xs.mapM LValue) (<- term e) + | .assign xs e => do 
assign (<- LValue ▷ xs) (<- term e) | .augAssign x op e => do stmt' (.assign [x] (.exprPos (.binOp op x e) (<- getPos))) | .annAssign _ _ .none => return () @@ -349,7 +341,7 @@ def traceKernel (k : Kernel) : Tracer Unit := do let kwargs <- k.kwargs.mapM fun (x,e) => return (x, <- term' e) function_call f args kwargs -def runKernel (k : Kernel) : Except String (List KLR.Stmt) := +def runKernel (k : Kernel) : Err (List KLR.Stmt) := tracer ⟨ ∅, #[] ⟩ do traceKernel k let g <- get diff --git a/NKL/Trace/Types.lean b/NKL/Trace/Types.lean index 5a18403..8328583 100644 --- a/NKL/Trace/Types.lean +++ b/NKL/Trace/Types.lean @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ import Lean +import NKL.Util import NKL.KLR import NKL.Python @@ -65,8 +66,6 @@ open NKL.KLR export Lean (Name) deriving instance Ord for Name -abbrev ErrorM := Except String - -- Terms are an extension of KLR.Expr, and they have types, which may be `any`. -- TODO: can we get rid of any? 
@@ -84,8 +83,8 @@ mutual structure Object where name : Name type : TermType - attr : String -> Except String Term - call : List Expr -> List (String × Expr) -> Except String Term + attr : String -> Err Term + call : List Expr -> List (String × Expr) -> Err Term inductive Term where | object : Object -> Term @@ -114,12 +113,18 @@ where instance : BEq Term where beq := Term.beq -partial def Term.type : Term -> ErrorM TermType +def Term.type : Term -> Err TermType | .object obj => return obj.type - | .tuple l => return .tuple (<- l.mapM Term.type) - | .list l => return .tuple (<- l.mapM Term.type) + | .tuple l => return .tuple (<- Term.type ▷ l) + | .list l => return .tuple (<- Term.type ▷ l) | .expr _ ty => return ty +def Term.toKLR : Term -> Err KLR.Expr + | .object obj => return .var obj.name.toString + | .tuple _ => throw "tuple cannot be converted to a KLR term" + | .list _ => throw "list cannot be converted to a KLR term" + | .expr e _ => return e + -- Our state has a number for generating fresh names, the current source -- location (for error reporting), and the local environment. The global -- environment is in the `Tracer` monad (below). 
@@ -142,15 +147,10 @@ def contains (s : State) (n : Name) : Bool := end State -abbrev TraceM := EStateM String State - -instance : MonadLift ErrorM TraceM where - monadLift - | .ok x => return x - | .error s => throw s +abbrev TraceM := StM State -- Run a trace with an empty initial environment -def trace (m : TraceM a) : ErrorM a := +def trace (m : TraceM a) : Err a := match m.run { } with | .ok x _ => return x | .error s _ => throw s @@ -221,12 +221,18 @@ inductive Item where | source : Python.Fun -> Item | term : Term -> Item -def Item.type : Item -> ErrorM TermType +def Item.type : Item -> Err TermType | .module n => return .any n | .global g => return .any g.name | .source _ => return .any "source".toName | .term t => t.type +def Item.toTerm : Item -> Err Term + | .module n => return .expr (.var n.toString) (.any "?".toName) + | .global g => return .expr (.var g.name.toString) (.any "?".toName) + | .source _ => throw "invalid use of source function" + | .term t => return t + structure Globals where env : Lean.RBMap Name Item compare body : Array Stmt @@ -285,7 +291,7 @@ def lookup_item (name : Name) : Tracer Item := do def add_stmt (s : Stmt) : Tracer Unit := modify fun g => { g with body := g.body.push s } -def tracer (g : Globals) (m : Tracer a) : ErrorM a := +def tracer (g : Globals) (m : Tracer a) : Err a := match trace (m.run g) with | .ok x => .ok x.fst | .error s => .error s diff --git a/NKL/Util.lean b/NKL/Util.lean new file mode 100644 index 0000000..385a305 --- /dev/null +++ b/NKL/Util.lean @@ -0,0 +1,47 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Govereau +-/ + +-- Common Utilities + +namespace NKL + +/- +The default choice for an error monad is `Except String`, used for simple +computations that can fail. + +This is defined as a notation so that it can be used within mutually recursive +inductive types without issues. 
(abbrev introduces a new definition which +cannot be used in a mutually recursive inductive) +-/ +notation "Err" => Except String + +/- +The default choice for a state monad is `EStateM String`. +Again, we use a notation for the same reason as for `Err`. + +Provide automatic lifting of Err, for any state monad instance. +-/ +notation "StM" => EStateM String + +instance : MonadLift Err (StM a) where + monadLift + | .ok x => .ok x + | .error s => .error s + +/- +A common issue is failure to prove termination automatically when using +List.mapM. There is a work-around for this which involves introducing +`{ x // x ∈ l }` in place of the list `l`. + +We can capture this trick in a notation. Note we need to use a notation and not +a definition because the proof object `x ∈ l` needs to be available to the +termination proof tactics, in the scope of the original function. + +Writing `List.mapM f l` as `f ▷ l` doesn't break the termination proof. +Note: ▷ is typed as \rhd +-/ +notation f "▷" l => + List.mapM (fun ⟨ x, _ ⟩ => f x) (List.attach l)