Skip to content

Commit

Permalink
feat: simple pretty-printer for KLR
Browse files Browse the repository at this point in the history
This change adds two related things: a pretty-printer for KLR terms,
and tensor names. Tensor names make the pretty printing nicer, but
have a second purpose. By naming all of the tensors, we can scan a KLR
kernel to collect up all of the input, output, and intermediate
tensors that will be needed to run the kernel. For argument tensors,
the generated tensor names are changed to the argument variable names;
this is just for readability.
  • Loading branch information
govereau committed Jan 22, 2025
1 parent 8757026 commit b94f357
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 20 deletions.
4 changes: 3 additions & 1 deletion NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import Lean
import NKL.KLR.Pretty
import NKL.Python
import NKL.Trace

namespace NKL
open NKL.KLR

local instance : MonadLift (Except String) IO where
monadLift
Expand All @@ -19,4 +21,4 @@ def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
let stmts <- NKL.Trace.runNKIKernel kernel
for s in stmts do
IO.println s!"{repr s}"
IO.println (" " ++ Lean.format s) --s!"{s}\n{repr s}"
22 changes: 14 additions & 8 deletions NKL/KLR/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@ portable format, a.k.a. Kernel Language Representation (KLR).

namespace NKL.KLR

-- TODO switch to tensor lib
-- TODO switch to TensorLib's version of these types
--export TensorLib (Tensor Dtype Shape)
-- Mostly, NKL deals with empty tensors, so just check dtype and shape
-- TODO: talk to Sean about a more general BEq for Tensor
--instance : BEq Tensor where
-- beq t₁ t₂ := t₁.dtype == t₂.dtype && t₁.shape == t₂.shape

abbrev Dtype := String
abbrev Shape := List Int
structure Tensor where

/-
A TensorName is essentially a typed variable, where the type
must be a tensor type. When we flush out Typ below we may replace
this with `Expr.var name (Typ.tensor dtype shape)`. For now, this
only refers to dynamic tensors, or compile-time tensors, not
trace-time tensors.
-/

structure TensorName where
name : String
dtype : Dtype
shape : Shape
deriving Repr, BEq
Expand Down Expand Up @@ -71,7 +77,7 @@ def toInt : Const -> Except String Int

end Const

-- This correspondes to the "Quasi-Affine Expressions" in Neuron.
-- This corresponds to the "Quasi-Affine Expressions" in Neuron.
-- Note, `floor` is the usual integer division.
inductive IndexExpr where
| var (name : String)
Expand All @@ -94,7 +100,7 @@ inductive Index where
inductive Expr where
| var (x : String)
| const (c : Const)
| tensor (t : Tensor)
| tensor (t : TensorName)
| access (t : Expr) (ix : List Index)
| call (f : Expr) (args : List Expr) (kwargs : List (String × Expr))
deriving Repr, BEq
Expand Down
6 changes: 3 additions & 3 deletions NKL/KLR/Encode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ private def chkIndex (i : Index) : Bool :=

partial def encExpr : Expr -> ByteArray
| .var s => tag 0x30 [encString s]
| .tensor t => tag 0x31 [encString t.dtype, encList encInt t.shape]
| .tensor t => tag 0x31 [encString t.name, encString t.dtype, encList encInt t.shape]
| .const c => tag 0x32 [encConst c]
| .access e ix => tag 0x33 [encExpr e, encList encIndex ix]
| .call f ax kw => tag 0x34 [encExpr f, encList encExpr ax, encList encKeyword kw]
Expand All @@ -276,7 +276,7 @@ where
partial def decExpr : DecodeM Expr := do
match (<- next) with
| 0x30 => return .var (<- decString)
| 0x31 => return .tensor $ .mk (<- decString) (<- decList decInt)
| 0x31 => return .tensor $ .mk (<- decString) (<- decString) (<- decList decInt)
| 0x32 => return .const (<- decConst)
| 0x33 => return .access (<- decExpr) (<- decList decIndex)
| 0x34 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword)
Expand All @@ -293,7 +293,7 @@ private def ixz := Index.coord (IndexExpr.int 0)

#guard chkExpr nil
#guard chkExpr (.var "var")
#guard chkExpr (.tensor $ .mk "float32" [1,2,3])
#guard chkExpr (.tensor $ .mk "t" "float32" [1,2,3])
#guard chkExpr (.const (.int 1))
#guard chkExpr (.access nil [ixz, ixz, ixz])
#guard chkExpr (.call nil [nil, nil, nil] [("a", nil), ("b", nil)])
Expand Down
80 changes: 80 additions & 0 deletions NKL/KLR/Pretty.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR.Basic

namespace NKL.KLR
open Std

/-
This is a simple pretty printer for KLR terms. At some point, we may want to
make this output valid python syntax that would parse and elaborate to the same
KLR kernel. At the moment, there are too many unknowns to spend time on this.
The format here is just for ease of debugging, feel free to modify as you wish.
-/

private def abracket (f : Format) : Format :=
Format.bracket "<" f ">"

private def ppArgs [ToFormat a] (l : List a) : Format :=
Format.joinSep l ","

def ppTensor (t : TensorName) : Format :=
"%" ++ t.name ++ abracket (t.dtype ++ ":" ++ ppArgs t.shape)

def ppConst : Const -> Format
| .none => "None"
| .bool true => "True"
| .bool false => "False"
| .int i => format i
| .float f => format f
| .string s => "\"" ++ s.push '"'

private def addParens : Nat -> Format -> Format
| 0, f => f
| _, f => f.paren

def ppIndexExpr (n : Nat) : IndexExpr -> Format
| .var x => x
| .int i => format i
| .neg e => "-" ++ ppIndexExpr (n+1) e
| .add l r => addParens n $ ppIndexExpr 1 l ++ "+" ++ ppIndexExpr 1 r
| .mul i e => addParens n $ format i ++ "*" ++ ppIndexExpr 1 e
| .floor e i => addParens n $ ppIndexExpr 1 e ++ "/" ++ format i
| .ceil e i => "ceil" ++ Format.paren (ppIndexExpr 0 e ++","++ format i)
| .mod e i => addParens n $ ppIndexExpr 1 e ++ "%" ++ format i

def ppIndexExpr? : Option IndexExpr -> Format
| none => "None"
| some e => ppIndexExpr 0 e

def ppIndex : Index -> Format
| .ellipsis => "..."
| .coord e => ppIndexExpr? e
| .slice l u s => .joinSep ([l,u,s].map ppIndexExpr?) ":"

partial def ppExpr : Expr -> Format
| .var x => x
| .const c => ppConst c
| .tensor t => ppTensor t
| .access t ix => .fill (ppExpr t ++ .sbracket (.joinSep (ix.map ppIndex) ","))
| .call f args kwargs =>
let args := args.map ppExpr
let kwargs := kwargs.map fun (x,e) => x ++ "=" ++ ppExpr e
.fill (ppExpr f ++ .paren (ppArgs (args ++ kwargs)))

def ppStmt : Stmt -> Format
| .pass => "pass"
| .expr e => ppExpr e
| .ret e => "ret" ++ ppExpr e
| .assign x e => x ++ " = " ++ ppExpr e
| .loop _ _ _ _ _ => "<loop>"

instance : ToFormat TensorName where format := ppTensor
instance : ToFormat Const where format := ppConst
instance : ToFormat IndexExpr where format := ppIndexExpr 0
instance : ToFormat Index where format := ppIndex
instance : ToFormat Expr where format := ppExpr
instance : ToFormat Stmt where format := ppStmt
13 changes: 10 additions & 3 deletions NKL/Trace/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ partial def expr' : Expr' -> Tracer Item
| .const c => return .term (<- const c)
| .tensor s dty => do
let shape <- s.mapM integer
return .term (.expr (.tensor ⟨ dty, shape ⟩) (.tensor dty shape))
let name <- genName "t".toName
return .term (.expr (.tensor ⟨ name.toString, dty, shape ⟩) (.tensor dty shape))
| .name id _ => lookup_item id.toName
| .attr (.exprPos e p) id _ => do withPos p ((<- expr' e).attr id)
| .tuple l _ => return .term (.tuple (<- l.mapM term))
Expand Down Expand Up @@ -190,6 +191,7 @@ partial def stmt' : Stmt' -> Tracer Unit
partial def bind_args (f : Fun)
(args : List Term)
(kwargs : List (String × Term))
(rename : Bool := false)
: Tracer (List (String × Term)) := do
if f.args.vararg != none || f.args.kwarg != none then
throw "var args not supported"
Expand All @@ -208,16 +210,21 @@ partial def bind_args (f : Fun)
return (x, <- term' e)
else
throw s!"argument {x} not supplied"
-- rename tensors if asked to
let argmap := if rename then argmap.map renameTensors else argmap
return argmap
where
renameTensors : String × Term -> String × Term
| (s, .expr (.tensor t) ty) => (s, .expr (.tensor {t with name := s}) ty)
| other => other

-- For a function call, first evaluate the argument in the current environment.
-- Then enter a new environment and evaluate the function statements.
partial def function_call (f : Fun)
(args : List Term)
(kwargs : List (String × Term))
: Tracer Unit := do
let args <- bind_args f args kwargs
--let args <- args.mapM fun (x,e) => return (x, e)
let args <- bind_args f args kwargs (rename:=true)
withSrc f.source $ enterFun $ do
args.forM fun (x,e) => do extend x.toName e
f.body.forM stmt
Expand Down
10 changes: 5 additions & 5 deletions NKL/Trace/Tensor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@ private def tensor_call (op : String) (args : List Expr) : Term :=

-- Unary operations on tensors

def tensor_op (op : String) (t : Tensor) : TraceM Term :=
def tensor_op (op : String) (t : TensorName) : TraceM Term :=
return tensor_call op [.tensor t]

-- Binary operations on tensors / scalars

def tensor_tensor (op : String) (l r : Tensor) : TraceM Term :=
def tensor_tensor (op : String) (l r : TensorName) : TraceM Term :=
return tensor_call op [.tensor l, .tensor r]

private def broadcast (t : Tensor) (c : Const) : Expr :=
private def broadcast (t : TensorName) (c : Const) : Expr :=
let args := t.shape.map fun i => Expr.const (.int i)
let args := .const c :: args
.call (.var "broadcast") args []

def tensor_scalar (op : String) (t : Tensor) (c : Const) : TraceM Term :=
def tensor_scalar (op : String) (t : TensorName) (c : Const) : TraceM Term :=
return tensor_call op [ .tensor t, broadcast t c]

def scalar_tensor (op : String) (c : Const) (t : Tensor) : TraceM Term :=
def scalar_tensor (op : String) (c : Const) (t : TensorName) : TraceM Term :=
return tensor_call op [ .tensor t, broadcast t c]

0 comments on commit b94f357

Please sign in to comment.