diff --git a/NKL/Trace/Basic.lean b/NKL/Trace/Basic.lean index 83167a5..aa3cacb 100644 --- a/NKL/Trace/Basic.lean +++ b/NKL/Trace/Basic.lean @@ -149,11 +149,99 @@ def unOp : String -> Term -> TraceM Term | "Not", t => return .expr (.const $ .bool (<- t.isFalse)) .bool | op, _ => throw s!"unimp {op}" --- Comparison operators +/- +Comparison operators + +These functions implement the Python comparison operators. For tensors, these +will be promoted to per-element operations, for everything else the should be +static. For example: + + # comparison of two lists containing integer constants + assert a_input.shape == b_input.shape + + # comparison of two integer constants + assert a_input.shape[0] <= nl.tile_size.pmax + +We only need Eq (==) and Lt (<), other operators are implemted in terms of +these two. +-/ + +private def exprEq : Expr -> Expr -> TraceM Bool + | .var x, .var y => return x == y + | .const c₁, .const c₂ => return c₁ == c₂ + | .tensor t₁, .tensor t₂ => return t₁.name == t₂.name + | _, _ => return false + +private partial def termEq : Term -> Term -> TraceM Bool + | .object o₁, .object o₂ => return o₁.name == o₂.name + | .tuple l₁, .tuple l₂ + | .list l₁, .list l₂ => do + if l₁.length != l₂.length then + return false + for (x,y) in (l₁.zip l₂) do + if not (<- termEq x y) then + return false + return true + | .expr e₁ _, .expr e₂ _ => exprEq e₁ e₂ + | _, _ => return false + +-- Python "is" operator is the same as == for all literals, except for lists. +private def termIsIdentical : Term -> Term -> TraceM Bool + | .list _, .list _ => return false + | l, r => termEq l r + +-- Python: contains operator: 1 in [1,2,3] +private def termIn (x : Term) : Term -> TraceM Bool + | .tuple l | .list l => l.anyM (termEq x) + | _ => throw "invalid use of in" + +private def constLt : Const -> Const -> TraceM Bool + -- comparison between same types + | .bool b₁, .bool b₂ => return !b₁ && b₂ + | .int l, .int r => return l < r + | .float l, .float r => return l < r + | .string l, .string r => return l < r + -- float promotion + | .float f, .bool b => return f < if b then 1.0 else 0.0 + | .bool b, .float f => return (if b then 1.0 else 0.0) < f + | .float f, .int i => return f < .ofInt i + | .int i, .float f => return .ofInt i < f + -- int promotion + | c, .int i => return (<- c.toInt) < i + | .int i, c => return i < (<- c.toInt) + -- errors + | .string _, _ | _, .string _ => throw "unsupported comparison" + | .none, _ | _, .none => throw "unsupported comparison" + +private def termLt : Term -> Term -> TraceM Bool + | .tuple l₁, .tuple l₂ + | .list l₁, .list l₂ => listLt l₁ l₂ + | .expr (.const c₁) _, .expr (.const c₂) _ => constLt c₁ c₂ + | _, _ => throw "unsupported comparison" +where + listLt : List Term -> List Term -> TraceM Bool + | [], [] => return false + | [], _ => return true + | _, [] => return false + | x :: xs, y :: ys => do + if <- termLt x y then return true + else return (<- termEq x y) && (<- listLt xs ys) def cmpOp : String -> Term -> Term -> TraceM Bool - | s, l, r => throw s!"unsupported comparison operator {s} {repr l} {repr r}" + | "Eq", l, r => termEq l r + | "NotEq", l, r => return not (<- termEq l r) + | "Lt", l, r => termLt l r + | "LtE", l, r => return (<- termEq l r) || (<- termLt l r) + | "Gt", l, r => return not (<- termEq l r) && not (<- termLt l r) + | "GtE", l, r => return not (<- termLt l r) + | "Is", l, r => termIsIdentical l r + | "IsNot", l, r => return not (<- termIsIdentical l r) + | "In", l, r => termIn l r + | "NotIn", l, r => return not (<- termIn l r) + | op, _, _ => throw s!"unsupported comparison operator {op}" +-- Python comparison chains are short-circuting +-- e.g. x < y < z => x < y || y < z def compare : Term -> List String -> List Term -> TraceM Term | x, [op], [y] => return bool (<- cmpOp op x y) | x, op::ops, y::ys => do @@ -164,6 +252,8 @@ def compare : Term -> List String -> List Term -> TraceM Term where bool b := .expr (.const $ .bool b) .bool +-- Attributes + def Term.attr : Term -> String -> TraceM Term | .object o, id => o.attr id | .expr _ (.tensor d _), "dtype" => return (str d) diff --git a/NKL/Trace/Python.lean b/NKL/Trace/Python.lean index 85106bf..dcf29ec 100644 --- a/NKL/Trace/Python.lean +++ b/NKL/Trace/Python.lean @@ -208,8 +208,8 @@ partial def klr (e : Expr) : Tracer KLR.Expr := partial def integer (e : Expr) : Tracer Int := do match <- klr e with - | .const c => return (<- c.toInt) - | _ => throw "expecting integer" + | .const c => return (<- c.toInt) + | _ => throw "expecting integer" partial def expr' : Expr' -> Tracer Item | .const c => return .term (<- const c) diff --git a/interop/test/test_basic.py b/interop/test/test_basic.py index dfbfb0e..4c35cec 100644 --- a/interop/test/test_basic.py +++ b/interop/test/test_basic.py @@ -1,3 +1,7 @@ +# This file exercises the Lean partial evaluator with +# a set of basic unit tests. Each function is parsed, +# handed to Lean, where it is checked and reduced to KLR. + import numpy as np import nki import pytest @@ -65,6 +69,28 @@ def expr_bool_op(t): 1 or None # evals to 1 (False,) or 1 # evals to (False,) +def expr_cmp_op(t): + assert 1 == 1 + assert [] == [] + assert not ([1,2] == [1]) + assert not ([] < []) + assert [] < [1] + assert not ([1,2] < [1,2]) + assert [1,1] < [1,2] + assert [1,2] < [1,2,3] + assert 1.2 < 2 + assert 1 < 1.2 + assert 1.2 < 1.3 + assert 0.5 < True + assert not (0.5 < False) + assert "a" < "ab" + assert (1,2) is (1,2) + assert not ([1,2] is [1,2]) + assert 1 in (1,2) + assert 1 in [3,2,1] + assert 1 not in (2,3,4) + assert 1 not in [] + def assign(t): x = y = 1 assert x == y @@ -79,6 +105,7 @@ def assign(t): assert b == 3 assert c == 4 +# test each function in turn @pytest.mark.parametrize("f", [ const_stmt, expr_name, @@ -86,12 +113,13 @@ def assign(t): expr_list, expr_subscript, expr_bool_op, + expr_cmp_op, assign ]) def test_succeed(f): t = np.ndarray(10) - F = Parser(f) - F(t) + F = Parser(f) # parse python + F(t) # specialize, and reduce to KLR # Failing cases # (These functions are expected to fail elaboration to KLR)