Skip to content

Commit

Permalink
feat: basic definitions to support tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
govereau committed Jan 9, 2025
1 parent e4dc75f commit aad52cc
Show file tree
Hide file tree
Showing 3 changed files with 375 additions and 0 deletions.
9 changes: 9 additions & 0 deletions NKL/Trace.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.Trace.Types
--import NKL.Trace.Basic
import NKL.Trace.Builtin
--import NKL.Trace.Python
98 changes: 98 additions & 0 deletions NKL/Trace/Builtin.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.Trace.Types

/-
# Utilities for creating Builtins and Globals
-/

namespace NKL.Trace
open NKL.KLR

-- Attribute lookup for a built-in: maps an attribute name to a `Term`, or
-- fails with an error message.
abbrev BuiltinAttr := String -> ErrorM Term
-- Attribute lookup for a global: same shape as `BuiltinAttr`, but runs in the
-- `TraceM` monad.
abbrev GlobalAttr := String -> TraceM Term

-- Call handler for a built-in: takes a list of `Expr` arguments and a list of
-- named `Expr` arguments (presumably Python keyword arguments -- TODO confirm).
abbrev BuiltinFn := List Expr -> List (String × Expr) -> ErrorM Term
-- Call handler for a global: same shape as `BuiltinFn`, but the arguments are
-- `Term`s and the handler runs in the `TraceM` monad.
abbrev GlobalFn := List Term -> List (String × Term) -> TraceM Term

-- Attribute handler for objects without attributes: every lookup fails with
-- an error that names the requested attribute and the owning object.
def noattrs [Monad m] [MonadExcept String m] : Name -> String -> m a
  | owner, field => throw s!"{field} is not an attribute of {owner}"

-- Call handler for objects that cannot be called: both argument lists are
-- ignored and an error naming the object is raised.
def uncallable [Monad m] [MonadExcept String m] : Name -> a -> b -> m c
  | owner, _, _ => throw s!"{owner} is not a callable type"

-- Wrap a Lean function as a callable built-in. The result has type
-- `.any name` and rejects every attribute lookup (see `noattrs`).

def simple_function (name : Name) (f : BuiltinFn) : Builtin :=
  { name, type := .any name, attr := noattrs name, call := f }

-- Wrap a Lean function as a callable built-in that also answers two
-- attributes: `__name__` (the name as a string constant) and `__call__`
-- (a plain `simple_function` over the same `f`). Any other attribute is an
-- error.

def python_function (name : Name) (f : BuiltinFn) : Builtin :=
  { name, type := .any name, call := f, attr := attrs }
where
  attrs (key : String) : ErrorM Term :=
    match key with
    | "__name__" => .ok (.expr (.const (.string name.toString)) .string)
    | "__call__" => .ok (.object (simple_function name f))
    | other => .error s!"unsupported attribute {other}"

-- Build a non-callable built-in from a table of named methods. Each table
-- entry is a function awaiting the object's payload `x`; looking up an
-- attribute applies the entry to `x` and wraps the result as a
-- `simple_function` named `name.attr`. Unknown attributes are an error.

def simple_object {a : Type}
    (name : Name)
    (attrs : List (String × (a -> BuiltinFn)))
    (x : a) : Builtin :=
  { name, type := .none, call := uncallable name, attr := attr_fn }
where
  attr_fn (attr : String) : Except String Term :=
    match attrs.find? (·.fst == attr) with
    | some (_, mk) => pure (.object (simple_function (name.str attr) (mk x)))
    | none => throw s!"{attr} is not an attribute of {name}"


-- Basic Python types could be represented as built-ins. In practice, we don't
-- do this as it is more convenient to have tuples, etc. represented directly
-- in the `Term` type. However, as an example, the tuple class may be defined
-- similar to below.

-- Note: not used
-- Example built-in modelling a Python tuple over a list of `Term`s. It
-- exposes two methods, both taking exactly one positional argument and no
-- named arguments:
--   count(x): number of elements equal to `x`
--   index(x): position of the first element equal to `x`, or an error
-- Elements are compared with `Term`'s `BEq` instance, which ignores the type
-- annotation of `.expr` terms, so the argument is wrapped as `.expr x .none`.
def tuple_obj : List Term -> Builtin :=
  simple_object "tuple".toName
    [ ("count", count_fn)
    , ("index", index_fn)
    ]
where
  -- Python's `tuple.count(x)` counts occurrences of `x`; it is not the length.
  count_fn : List Term -> BuiltinFn
    | l, [x], [] => return .expr (.const $ .int (l.count (.expr x .none))) .int
    | _, _, _ => throw "invalid arguments"
  index_fn : List Term -> BuiltinFn
    | l, [x], [] => match l.indexOf? (.expr x .none) with
      | none => throw s!"{repr x} not in tuple"
      | some i => return .expr (.const $ .int i) .int
    | _, _, _ => throw "invalid arguments"

-- Note: not used
-- Example global modelling the tuple class: it has no attributes, and calling
-- it with positional arguments only builds a `tuple_obj` from them. Any named
-- argument is rejected.
def tuple_class : Global :=
  let cls : Name := "class_tuple".toName
  { name := cls, attr := noattrs cls, call := make_tuple }
where
  make_tuple (args : List Term) (kwargs : List (String × Term)) : TraceM Term :=
    match kwargs with
    | [] => pure (.object (tuple_obj args))
    | _ :: _ => throw "invalid arguments"
268 changes: 268 additions & 0 deletions NKL/Trace/Types.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import Lean
import NKL.KLR
import NKL.Python

/-
# Basic types for tracing
Tracing is a special form of partial evaluation. After parsing the original
python terms, they are "traced" to produce simpler KLR terms. Therefore the
input to tracing is a `NKL.Python` AST, and the output is a `NKL.KLR` AST. The
tracing process introduces an intermediate AST, called `Term` which is an
extension of the `KLR.Expr` type. The `Term` type is used to represent
sub-expressions that are in the process of being traced, but not yet
representable in KLR. For example, consider the statement:
a, b = t.shape
Here, the left-hand side is a tuple of variables (which cannot be represented in
KLR), and the right-hand side is a built-in attribute of a tensor. During
tracing, the expression elaborator needs to return a tuple of variables `(a,b)`,
and the built-in `shape` attribute needs to return a tuple of integers. With
both of these in hand, the statement elaborator can expand the tuples into two
KLR assignment statements (or generate an error). The intermediate tuples are
elements of the type `Term`, and the final statements are elements of the type
`KLR.Stmt`.
Tracing takes place within a pair of nested state monads called `TraceM` (the
inner one), and `Tracer` (the outer one). Most code only needs to use the
`TraceM` monad (more explanation of `Tracer`, and why we need it, is given later
in this file). The `TraceM` monad provides access to an environment which
contains bindings for all of the local variables currently in scope.
All local variables refer to `Term`s and hence may not be fully reduced. At the
completion of tracing, all terms must be reducible to `KLR.Expr`s or an error is
generated. This requirement is due to an implicit phase separation in the design
of NKI: some terms must be eliminated by tracing, and some terms can be passed
to the compiler. KLR only represents terms which can be passed on to the
compiler. For example, assertions have to be resolved at tracing time, neither
the compiler nor the hardware can raise an assertion that the user can see
during development. Hence, KLR does not have an assert statement, and any
expressions under an assert must be resolved during tracing. Other examples are
conditional expressions, and loop bounds; both of which must be resolved during
tracing.
In addition to `Term`s, the environment can contain built-in objects. Built-in
objects are defined using Lean functions that generate either other built-ins or
terms. For example, accessing an attribute of a built-in object may produce a
built-in function, which, when called, generates a KLR expression. There are
also built-ins at the `Tracer` monad level, which can generate KLR statements as
well as modify the environment of the `TraceM` monad.
This module defines types to represent the built-ins, the environments, and the
tracing monads.
-/

namespace NKL.Trace
open NKL.KLR

-- Lean already has a perfectly nice hierarchical string type
export Lean (Name)
-- An ordering on `Name` is derived here; the `RBMap`s declared in this file
-- are keyed by `Name` and built with `compare`.
deriving instance Ord for Name

-- Errors are plain strings carried by `Except`.
abbrev ErrorM := Except String

-- Terms are an extension of KLR.Expr, and they have types, which may be `any`.
-- TODO: can we get rid of any?

-- The type attached to a `Term`.
inductive TermType where
-- simple types without parameters
| none | bool | int | float | string
-- catch-all type tagged with a name; used in this file for built-in
-- functions, modules, globals and python sources
| any : Name -> TermType
-- tuple type with one entry per element
| tuple : List TermType -> TermType
-- list type with one entry per element
| list : List TermType -> TermType
-- tensor type carrying a KLR `Dtype` and `Shape`
| tensor : Dtype -> Shape -> TermType
deriving Repr, BEq

-- `Builtin` and `Term` are mutually recursive: a built-in's handlers produce
-- terms, and a term may be a built-in object.
mutual
-- A built-in object implemented by Lean functions.
structure Builtin where
-- name of the built-in (shown by the `Repr Term` instance)
name : Name
-- type reported for this object by `Term.type`
type : TermType
-- attribute lookup: attribute name to a term, or an error message
attr : String -> Except String Term
-- call handler: a list of arguments and a list of named arguments
-- (presumably Python keyword arguments -- TODO confirm)
call : List Expr -> List (String × Expr) -> Except String Term

-- Intermediate values produced while tracing.
inductive Term where
-- a built-in object
| object : Builtin -> Term
-- a tuple of terms
| tuple : List Term -> Term
-- a list of terms
| list : List Term -> Term
-- a KLR expression together with its type
| expr : Expr -> TermType -> Term
end

-- Compact rendering of terms: objects show only their name, tuples and lists
-- only their length, and expressions are printed as `expr:type`.
instance : Repr Term where
  reprPrec
    | .object obj, _ => .text s!"object<{obj.name}>"
    | .tuple items, _ => .text s!"tuple<{items.length}>"
    | .list items, _ => .text s!"list<{items.length}>"
    | .expr e ty, prec => reprPrec e prec ++ ":" ++ reprPrec ty prec

-- Structural equality on terms.
--  * tuples (and lists) are equal when they have the same length and their
--    elements are pairwise equal
--  * expressions are compared with `==` on the `Expr` only; the `TermType`
--    annotation is ignored
--  * every other combination is unequal, including two `.object` terms and a
--    tuple compared with a list
def Term.beq : Term -> Term -> Bool
| .tuple l , .tuple r => lst_eq l r
| .list l , .list r => lst_eq l r
| .expr l _, .expr r _ => l == r
| _, _ => false
where
-- element-wise comparison; lists of different length are unequal
lst_eq : List Term -> List Term -> Bool
| [] , [] => true
| l :: ls, r :: rs => Term.beq l r && lst_eq ls rs
| _, _ => false

-- `==` on terms is `Term.beq`
instance : BEq Term where beq := Term.beq

-- Compute the type of a term.
--  * objects report the type stored in their `Builtin`
--  * tuples and lists report the types of their elements, wrapped in the
--    matching `.tuple` / `.list` constructor
--  * expressions report their type annotation
partial def Term.type : Term -> ErrorM TermType
  | .object obj => return obj.type
  | .tuple l => return .tuple (<- l.mapM Term.type)
  -- a list term has a list type (this case used to return `.tuple`)
  | .list l => return .list (<- l.mapM Term.type)
  | .expr _ ty => return ty

-- Our state has a number for generating fresh names, the current source
-- location (for error reporting), and the local environment. The global
-- environment is in the `Tracer` monad (below).

export NKL.Python (Pos)
-- The local environment: a map from names to terms.
abbrev Env := Lean.RBMap Name Term compare

-- The state of the `TraceM` monad. All fields have defaults, so `{ }` is a
-- valid initial state.
structure State where
-- counter used by `genName` to build fresh names
fvn : Nat := 0
-- current source position
pos : Pos := { }
-- bindings of the local variables
env : Env := ∅

namespace State

-- Build a state whose environment holds the given bindings; each string key
-- is converted with `String.toName`. The other fields keep their defaults.
def ofList (l : List (String × Term)) : State :=
  let bindings := l.map fun (key, term) => (key.toName, term)
  { env := Lean.RBMap.ofList bindings }

-- True when the local environment has a binding for `n`.
def contains (s : State) (n : Name) : Bool :=
  (s.env.find? n).isSome

end State

-- The inner tracing monad: a state monad over `State` with `String` errors.
abbrev TraceM := EStateM String State

-- Pure `ErrorM` computations can be used inside `TraceM`: a success becomes a
-- pure value and an error is re-thrown with the same message.
instance : MonadLift ErrorM TraceM where
  monadLift result :=
    match result with
    | .error msg => throw msg
    | .ok value => pure value

-- Run a `TraceM` computation starting from the default (empty) state. The
-- final state is discarded; only the result or the error message is kept.
def trace (m : TraceM a) : ErrorM a :=
  match EStateM.run m {} with
  | .error msg _ => .error msg
  | .ok value _ => .ok value

-- Generate a fresh name: bump the `fvn` counter and append the new counter
-- value to `name` as a numeric component.
def genName (name : Name := .anonymous) : TraceM Name := do
  let fresh := (<- get).fvn + 1
  modify fun st => { st with fvn := fresh }
  return .num name fresh

-- Bind `x` to `v` in the local environment.
def extend (x : Name) (v : Term) : TraceM Unit := do
  let st <- get
  set { st with env := st.env.insert x v }

-- Look up `name` in the local environment; `none` when it is unbound.
def lookup? (name : Name) : TraceM (Option Term) := do
  return (<- get).env.find? name

-- Look up `name` in the local environment; an unbound name is an error.
def lookup (name : Name) : TraceM Term := do
  let some term := (<- get).env.find? name
    | throw s!"{name} not found"
  return term


-- The Tracer monad is set up similarly to the TraceM monad, but "one level up".
-- Built-ins in Tracer can operate over Terms and within the TraceM monad, hence
-- they can query and modify the state of the inner monad.

-- The global environment contains module declarations, global objects, python
-- sources of user kernels, and possibly `Term`s. A Term may appear in the
-- global environment as a convenient way to represent global constants, such as
-- `nki.language.psum`. Only the core tracing code needs to use the `Tracer`
-- monad.

-- The global state also keeps the list of translated statements for the
-- current kernel. The current source position used for error reporting is
-- stored in the inner `State` (see `getPos` and `withPos`).

-- A global object implemented by Lean functions. It mirrors `Builtin`, but
-- its handlers take `Term`s and run in the `TraceM` monad.
structure Global where
-- name of the global
name : Name
-- attribute lookup: attribute name to a term
attr : String -> TraceM Term
-- call handler: a list of arguments and a list of named arguments
-- (presumably Python keyword arguments -- TODO confirm)
call : List Term -> List (String × Term) -> TraceM Term

-- An entry of the global environment.
inductive Item where
-- a module, identified by name
| module : Name -> Item
-- a global object
| global : Global -> Item
-- a python function definition
| source : Python.Fun -> Item
-- a term
| term : Term -> Item

-- The type of a global item. Terms report their own type; everything else is
-- `.any`, tagged with the module name, the global's name, or the fixed name
-- `source`.
def Item.type : Item -> ErrorM TermType
  | .term t => t.type
  | .source _ => .ok (.any "source".toName)
  | .global g => .ok (.any g.name)
  | .module n => .ok (.any n)

-- The state of the outer `Tracer` monad.
structure Globals where
-- the global environment: a map from names to items
env : Lean.RBMap Name Item compare
-- statements accumulated by `add_stmt`
body : Array Stmt

-- The outer monad
-- A state monad over `Globals` layered on top of `TraceM`.
abbrev Tracer := StateT Globals TraceM

-- Read the inner (`TraceM`) state from within `Tracer`; nothing is modified.
def getState : Tracer State :=
  fun g s => .ok (s, g) s
-- Replace the inner (`TraceM`) state from within `Tracer`; the globals are
-- left untouched.
def setState (s : State) : Tracer Unit :=
  fun g _ => .ok ((), g) s

-- Read the current source position from the inner state.
def getPos : Tracer Pos := do
  let st <- getState
  return st.pos

-- Set the current source position to `p`, then run `m`. The previous position
-- is not restored afterwards.
def withPos (p : Pos) (m : Tracer a) : Tracer a := do
  let st <- getState
  setState { st with pos := p }
  m

-- Run `m` starting from the default position. If it fails, re-throw the error
-- after passing it through `Python.Parsing.genError` together with `source`
-- and the position recorded at the time of the failure.
def withSrc (source : String) (m : Tracer a) : Tracer a := do
  try
    withPos {} m
  catch msg =>
    throw (Python.Parsing.genError source msg (<- getPos))

-- Run `m` in a new scope. On success the inner state from before the call is
-- put back, while changes to the globals are kept. On failure the saved state
-- is also put back, except for the source position, which is taken from the
-- point of failure so that the error handler in `withSrc` reports the right
-- location. This relies on the position being set on every statement and
-- expression while traversing the tree.
-- NOTE(review): the restored state includes `fvn`, so the fresh-name counter
-- is rewound on exit as well -- confirm this is intended.

def enter (m : Tracer a) : Tracer a :=
  fun globals saved =>
    match m globals saved with
    | .ok result _ => .ok result saved
    | .error msg final => .error msg { saved with pos := final.pos }

-- Run `m` in a new function scope: like `enter`, but `m` starts with an empty
-- local environment.

def enterFun (m : Tracer a) : Tracer a :=
  enter do
    setState { (<- getState) with env := ∅ }
    m

-- Bind `name` to item `i` in the global environment.
def extend_global (name : Name) (i : Item) : Tracer Unit := do
  let g <- get
  set { g with env := g.env.insert name i }

-- Look up `name` in the global environment; an unbound name is an error.
def lookup_global (name : Name) : Tracer Item := do
  let some item := (<- get).env.find? name
    | throw s!"{name} not found"
  return item

-- Resolve `name`, preferring local bindings: a local term is wrapped as
-- `.term`, otherwise the global environment is consulted.
def lookup_item (name : Name) : Tracer Item := do
  let some term <- lookup? name
    | lookup_global name
  return .term term

-- Append a statement to the body accumulated in the globals.
def add_stmt (s : Stmt) : Tracer Unit := do
  let g <- get
  set { g with body := g.body.push s }

-- Run a `Tracer` computation from the given globals and an empty inner state.
-- Only the result is returned; the final globals and inner state are dropped.
def tracer (g : Globals) (m : Tracer a) : ErrorM a :=
  Prod.fst <$> trace (StateT.run m g)

0 comments on commit aad52cc

Please sign in to comment.