-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: basic definitions to support tracing
- Loading branch information
Showing
3 changed files
with
375 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
/- | ||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Paul Govereau | ||
-/ | ||
import NKL.Trace.Types | ||
--import NKL.Trace.Basic | ||
import NKL.Trace.Builtin | ||
--import NKL.Trace.Python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
/- | ||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Paul Govereau | ||
-/ | ||
import NKL.KLR | ||
import NKL.Trace.Types | ||
|
||
/- | ||
# Utilities for creating Builtins and Globals | ||
-/ | ||
|
||
namespace NKL.Trace | ||
open NKL.KLR | ||
|
||
abbrev BuiltinAttr := String -> ErrorM Term | ||
abbrev GlobalAttr := String -> TraceM Term | ||
|
||
abbrev BuiltinFn := List Expr -> List (String × Expr) -> ErrorM Term | ||
abbrev GlobalFn := List Term -> List (String × Term) -> TraceM Term | ||
|
||
def noattrs [Monad m] [MonadExcept String m] : Name -> String -> m a := | ||
fun name attr => throw s!"{attr} is not an attribute of {name}" | ||
|
||
def uncallable [Monad m] [MonadExcept String m] : Name -> a -> b -> m c := | ||
fun name _ _ => throw s!"{name} is not a callable type" | ||
|
||
-- Create a built-in representing a function; no attributes supported. | ||
|
||
def simple_function (name : Name) (f : BuiltinFn) : Builtin := | ||
{ name := name | ||
, type := .any name | ||
, attr := noattrs name | ||
, call := f | ||
} | ||
|
||
-- Create a built-in representing a function; basic attributes supported. | ||
|
||
def python_function (name : Name) (f : BuiltinFn) : Builtin := | ||
{ name := name | ||
, type := .any name | ||
, attr := attrs | ||
, call := f | ||
} | ||
where | ||
attrs : BuiltinAttr | ||
| "__name__" => return .expr (.const $ .string name.toString) .string | ||
| "__call__" => return .object (simple_function name f) | ||
| a => throw s!"unsupported attribute {a}" | ||
|
||
-- Create a built-in representing a simple object from a list of attributes. | ||
|
||
def simple_object {a : Type} | ||
(name : Name) | ||
(attrs : List (String × (a -> BuiltinFn))) | ||
(x : a) : Builtin := | ||
{ name := name | ||
, type := .none | ||
, attr := attr_fn | ||
, call := uncallable name | ||
} | ||
where | ||
attr_fn (attr : String) : Except String Term := | ||
match attrs.find? (fun x => x.fst == attr) with | ||
| none => .error s!"{attr} is not an attribute of {name}" | ||
| some (_,fn) => .ok (.object $ simple_function (name.str attr) (fn x)) | ||
|
||
|
||
-- Basic Python types could be represented as built-ins. In practice, we don't | ||
-- do this as it is more convenient to have tuples, etc. represented directly | ||
-- in the `Term` type. However, as an example, the tuple class may be defined | ||
-- similar to below. | ||
|
||
-- Note: not used | ||
def tuple_obj : List Term -> Builtin := | ||
simple_object "tuple".toName | ||
[ ("count", fun l _ _ => .ok (.expr (.const $ .int l.length) .int)) | ||
, ("index", index_fn) | ||
] | ||
where | ||
index_fn : List Term -> BuiltinFn | ||
| l, [x], [] => match l.indexOf? (.expr x .none) with | ||
| none => throw s!"{repr x} not in tuple" | ||
| some i => return .expr (.const $ .int i) .int | ||
| _, _, _ => throw "invalid arguments" | ||
|
||
-- Note: not used | ||
def tuple_class : Global := | ||
let name := "class_tuple".toName | ||
{ name := name | ||
, attr := noattrs name | ||
, call := make_tuple | ||
} | ||
where | ||
make_tuple : GlobalFn | ||
| args, [] => return .object (tuple_obj args) | ||
| _, _ => throw "invalid arguments" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
/- | ||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Paul Govereau | ||
-/ | ||
import Lean | ||
import NKL.KLR | ||
import NKL.Python | ||
|
||
/- | ||
# Basic types for tracing | ||
Tracing is a special form of partial evaluation. After parsing the original | ||
python terms, they are "traced" to produce simpler KLR terms. Therefore the | ||
input to tracing is a `NKL.Python` AST, and the output is a `NKL.KLR` AST. The | ||
tracing process introduces an intermediate AST, called `Term` which is an | ||
extension of the `KLR.Expr` type. The `Term` type is used to represent | ||
sub-expressions that are in the process of being traced, but not yet | ||
representable in KLR. For example, consider the statement: | ||
a, b = t.shape | ||
Here, the left-hand side is a tuple of variables (which cannot be represented in | ||
KLR), and the right-hand side is a built-in attribute of a tensor. During | ||
tracing, the expression elaborator needs to return a tuple of variables `(a,b)`, | ||
and the built-in `shape` attribute needs to return a tuple of integers. With | ||
both of these in hand, the statement elaborator can expand the tuples into two | ||
KLR assignment statements (or generate an error). The intermediate tuples are | ||
elements of the type `Term`, and the final statements are elements of the type | ||
`KLR.Stmt`. | ||
Tracing takes place within a pair of nested state monads called `TraceM` (the | ||
inner one), and `Tracer` (the outer one). Most code only needs to use the | ||
`TraceM` monad (more explanation of `Tracer`, and why we need it, is given later | ||
in this file). The `TraceM` monad provides access to an environment which | ||
contains bindings for all of the local variables currently in scope. | ||
All local variables refer to `Term`s and hence may not be fully reduced. At the | ||
completion of tracing, all terms must be reducible to `KLR.Expr`s or an error is | ||
generated. This requirement is due to an implicit phase separation in the design | ||
of NKI: some terms must be eliminated by tracing, and some terms can be passed | ||
to the compiler. KLR only represents terms which can be passed on to the | ||
compiler. For example, assertions have to be resolved at tracing time, neither | ||
the compiler nor the hardware can raise an assertion that the user can see | ||
during development. Hence, KLR does not have an assert statement, and any | ||
expressions under an assert must be resolved during tracing. Other examples are | ||
conditional expressions, and loop bounds; both of which must be resolved during | ||
tracing. | ||
In addition to `Term`s, the environment can contain built-in objects. Built-in | ||
objects are defined using Lean functions that generate either other built-ins or | ||
terms. For example, accessing an attribute of a built-in object may produce a | ||
built-in function, which, when called, generates a KLR expression. There are | ||
also built-ins at the `Tracer` monad level, which can generate KLR statements as | ||
well as modify the environment of the `TraceM` monad. | ||
This module defines types to represent the built-ins, the environments, and the | ||
tracing monads. | ||
-/ | ||
|
||
namespace NKL.Trace | ||
open NKL.KLR | ||
|
||
-- Lean already has a perfectly nice hierarchical string type | ||
export Lean (Name) | ||
deriving instance Ord for Name | ||
|
||
abbrev ErrorM := Except String | ||
|
||
-- Terms are an extension of KLR.Expr, and they have types, which may be `any`. | ||
-- TODO: can we get rid of any? | ||
|
||
inductive TermType where | ||
| none | bool | int | float | string | ||
| any : Name -> TermType | ||
| tuple : List TermType -> TermType | ||
| list : List TermType -> TermType | ||
| tensor : Dtype -> Shape -> TermType | ||
deriving Repr, BEq | ||
|
||
mutual | ||
structure Builtin where | ||
name : Name | ||
type : TermType | ||
attr : String -> Except String Term | ||
call : List Expr -> List (String × Expr) -> Except String Term | ||
|
||
inductive Term where | ||
| object : Builtin -> Term | ||
| tuple : List Term -> Term | ||
| list : List Term -> Term | ||
| expr : Expr -> TermType -> Term | ||
end | ||
|
||
instance : Repr Term where | ||
reprPrec b n := match b with | ||
| .object obj => .text s!"object<{obj.name}>" | ||
| .tuple l => .text s!"tuple<{l.length}>" | ||
| .list l => .text s!"list<{l.length}>" | ||
| .expr e ty => reprPrec e n ++ ":" ++ reprPrec ty n | ||
|
||
def Term.beq : Term -> Term -> Bool | ||
| .tuple l , .tuple r => lst_eq l r | ||
| .list l , .list r => lst_eq l r | ||
| .expr l _, .expr r _ => l == r | ||
| _, _ => false | ||
where | ||
lst_eq : List Term -> List Term -> Bool | ||
| [] , [] => true | ||
| l :: ls, r :: rs => Term.beq l r && lst_eq ls rs | ||
| _, _ => false | ||
|
||
instance : BEq Term where beq := Term.beq | ||
|
||
partial def Term.type : Term -> ErrorM TermType | ||
| .object obj => return obj.type | ||
| .tuple l => return .tuple (<- l.mapM Term.type) | ||
| .list l => return .tuple (<- l.mapM Term.type) | ||
| .expr _ ty => return ty | ||
|
||
-- Our state has a number for generating fresh names, the current source | ||
-- location (for error reporting), and the local environment. The global | ||
-- environment is in the `Tracer` monad (below). | ||
|
||
export NKL.Python (Pos) | ||
abbrev Env := Lean.RBMap Name Term compare | ||
|
||
structure State where | ||
fvn : Nat := 0 | ||
pos : Pos := { } | ||
env : Env := ∅ | ||
|
||
namespace State | ||
|
||
def ofList (l : List (String × Term)) : State := | ||
{ env := Lean.RBMap.ofList $ l.map fun (s,i) => (s.toName, i) } | ||
|
||
def contains (s : State) (n : Name) : Bool := | ||
s.env.contains n | ||
|
||
end State | ||
|
||
abbrev TraceM := EStateM String State | ||
|
||
instance : MonadLift ErrorM TraceM where | ||
monadLift | ||
| .ok x => return x | ||
| .error s => throw s | ||
|
||
-- Run a trace with an empty initial environment | ||
def trace (m : TraceM a) : ErrorM a := | ||
match m.run { } with | ||
| .ok x _ => return x | ||
| .error s _ => throw s | ||
|
||
-- generate a fresh name using an existing name as a prefix | ||
def genName (name : Name := .anonymous) : TraceM Name := do | ||
let s <- get | ||
let n := s.fvn + 1 | ||
set { s with fvn := n } | ||
return .num name n | ||
|
||
-- add a new binding to the local environment | ||
def extend (x : Name) (v : Term) : TraceM Unit := | ||
modify fun s => {s with env := s.env.insert x v} | ||
|
||
-- lookup a name in the local environment | ||
def lookup? (name : Name) : TraceM (Option Term) := do | ||
let s <- get | ||
return s.env.find? name | ||
|
||
def lookup (name : Name) : TraceM Term := do | ||
match (<- get).env.find? name with | ||
| none => throw s!"{name} not found" | ||
| some x => return x | ||
|
||
|
||
-- The Tracer monad is setup similar to the TraceM monad, but "one level up". | ||
-- Built-ins in Tracer can operate over Terms and within the TraceM monad, hence | ||
-- they can query and modify the state of the inner monad. | ||
|
||
-- The global environment contains module declarations, global objects, python | ||
-- sources of user kernels, and possibly `Term`s. A Term may appear in the | ||
-- global environment as a convient way to represent global constants, such as | ||
-- `nki.language.psum`. Only the core tracing code needs to use the `Tracer` | ||
-- monad. | ||
|
||
-- The global environment also track the current source position for error | ||
-- reporting, and keeps the list of translated statements for the current | ||
-- kernel. | ||
|
||
structure Global where | ||
name : Name | ||
attr : String -> TraceM Term | ||
call : List Term -> List (String × Term) -> TraceM Term | ||
|
||
inductive Item where | ||
| module : Name -> Item | ||
| global : Global -> Item | ||
| source : Python.Fun -> Item | ||
| term : Term -> Item | ||
|
||
def Item.type : Item -> ErrorM TermType | ||
| .module n => return .any n | ||
| .global g => return .any g.name | ||
| .source _ => return .any "source".toName | ||
| .term t => t.type | ||
|
||
structure Globals where | ||
env : Lean.RBMap Name Item compare | ||
body : Array Stmt | ||
|
||
-- The outer monad | ||
abbrev Tracer := StateT Globals TraceM | ||
|
||
def getState : Tracer State := (get : TraceM State) | ||
def setState (s : State) : Tracer Unit := set s | ||
|
||
def getPos : Tracer Pos := | ||
fun g s => .ok (s.pos, g) s | ||
|
||
def withPos (p : Pos) (m : Tracer a) : Tracer a := | ||
fun g s => m g { s with pos := p } | ||
|
||
def withSrc (source : String) (m : Tracer a) : Tracer a := | ||
try withPos {} m | ||
catch e => do | ||
let pos <- getPos | ||
throw (Python.Parsing.genError source e pos) | ||
|
||
-- Enter a new scope, replacing the local state on exit. Note: we preserve the | ||
-- position in the error case so the error handler (above) will get the | ||
-- correct position on error. This should be fine since we are setting the | ||
-- position on every statement and expression while traversing the tree. | ||
|
||
def enter (m : Tracer a) : Tracer a := | ||
fun g s => match m g s with | ||
| .ok (x, g') _ => .ok (x, g') s | ||
| .error err s' => .error err { s with pos := s'.pos } | ||
|
||
-- Enter a new function scope, removing all local bindings | ||
|
||
def enterFun (m : Tracer a) : Tracer a := | ||
enter $ do | ||
let s <- getState | ||
setState { s with env := ∅ } | ||
m | ||
|
||
def extend_global (name : Name) (i : Item) : Tracer Unit := | ||
modify fun s => { s with env := s.env.insert name i } | ||
|
||
def lookup_global (name : Name) : Tracer Item := do | ||
match (<- get).env.find? name with | ||
| none => throw s!"{name} not found" | ||
| some x => return x | ||
|
||
def lookup_item (name : Name) : Tracer Item := do | ||
match (<- lookup? name) with | ||
| some x => return .term x | ||
| none => lookup_global name | ||
|
||
def add_stmt (s : Stmt) : Tracer Unit := | ||
modify fun g => { g with body := g.body.push s } | ||
|
||
def tracer (g : Globals) (m : Tracer a) : ErrorM a := | ||
match trace (m.run g) with | ||
| .ok x => .ok x.fst | ||
| .error s => .error s |