Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement comparison operators #24

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 92 additions & 2 deletions NKL/Trace/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,99 @@ def unOp : String -> Term -> TraceM Term
| "Not", t => return .expr (.const $ .bool (<- t.isFalse)) .bool
| op, _ => throw s!"unimp {op}"

-- Comparison operators
/-
Comparison operators

These functions implement the Python comparison operators. For tensors, these
will be promoted to per-element operations, for everything else the should be
static. For example:

# comparison of two lists containing integer constants
assert a_input.shape == b_input.shape

# comparison of two integer constants
assert a_input.shape[0] <= nl.tile_size.pmax

We only need Eq (==) and Lt (<), other operators are implemted in terms of
these two.
-/

private def exprEq : Expr -> Expr -> TraceM Bool
| .var x, .var y => return x == y
| .const c₁, .const c₂ => return c₁ == c₂
| .tensor t₁, .tensor t₂ => return t₁.name == t₂.name
| _, _ => return false

private partial def termEq : Term -> Term -> TraceM Bool
| .object o₁, .object o₂ => return o₁.name == o₂.name
| .tuple l₁, .tuple l₂
| .list l₁, .list l₂ => do
if l₁.length != l₂.length then
return false
for (x,y) in (l₁.zip l₂) do
govereau marked this conversation as resolved.
Show resolved Hide resolved
if not (<- termEq x y) then
return false
return true
| .expr e₁ _, .expr e₂ _ => exprEq e₁ e₂
| _, _ => return false

-- Python "is" operator is the same as == for all literals, except for lists.
private def termIsIdentical : Term -> Term -> TraceM Bool
| .list _, .list _ => return false
govereau marked this conversation as resolved.
Show resolved Hide resolved
| l, r => termEq l r

-- Python: contains operator: 1 in [1,2,3]
private def termIn (x : Term) : Term -> TraceM Bool
| .tuple l | .list l => l.anyM (termEq x)
| _ => throw "invalid use of in"

private def constLt : Const -> Const -> TraceM Bool
govereau marked this conversation as resolved.
Show resolved Hide resolved
-- comparison between same types
| .bool b₁, .bool b₂ => return !b₁ && b₂
| .int l, .int r => return l < r
| .float l, .float r => return l < r
| .string l, .string r => return l < r
-- float promotion
| .float f, .bool b => return f < if b then 1.0 else 0.0
| .bool b, .float f => return (if b then 1.0 else 0.0) < f
| .float f, .int i => return f < .ofInt i
| .int i, .float f => return .ofInt i < f
-- int promotion
| c, .int i => return (<- c.toInt) < i
govereau marked this conversation as resolved.
Show resolved Hide resolved
| .int i, c => return i < (<- c.toInt)
-- errors
| .string _, _ | _, .string _ => throw "unsupported comparison"
| .none, _ | _, .none => throw "unsupported comparison"

private def termLt : Term -> Term -> TraceM Bool
| .tuple l₁, .tuple l₂
| .list l₁, .list l₂ => listLt l₁ l₂
| .expr (.const c₁) _, .expr (.const c₂) _ => constLt c₁ c₂
| _, _ => throw "unsupported comparison"
where
listLt : List Term -> List Term -> TraceM Bool
| [], [] => return false
| [], _ => return true
| _, [] => return false
| x :: xs, y :: ys => do
if <- termLt x y then return true
else return (<- termEq x y) && (<- listLt xs ys)

def cmpOp : String -> Term -> Term -> TraceM Bool
| s, l, r => throw s!"unsupported comparison operator {s} {repr l} {repr r}"
| "Eq", l, r => termEq l r
govereau marked this conversation as resolved.
Show resolved Hide resolved
| "NotEq", l, r => return not (<- termEq l r)
| "Lt", l, r => termLt l r
| "LtE", l, r => return (<- termEq l r) || (<- termLt l r)
| "Gt", l, r => return not (<- termEq l r) && not (<- termLt l r)
| "GtE", l, r => return not (<- termLt l r)
| "Is", l, r => termIsIdentical l r
| "IsNot", l, r => return not (<- termIsIdentical l r)
| "In", l, r => termIn l r
| "NotIn", l, r => return not (<- termIn l r)
| op, _, _ => throw s!"unsupported comparison operator {op}"

-- Python comparison chains are short-circuting
-- e.g. x < y < z => x < y || y < z
def compare : Term -> List String -> List Term -> TraceM Term
| x, [op], [y] => return bool (<- cmpOp op x y)
| x, op::ops, y::ys => do
Expand All @@ -164,6 +252,8 @@ def compare : Term -> List String -> List Term -> TraceM Term
where
bool b := .expr (.const $ .bool b) .bool

-- Attributes

def Term.attr : Term -> String -> TraceM Term
| .object o, id => o.attr id
| .expr _ (.tensor d _), "dtype" => return (str d)
Expand Down
4 changes: 2 additions & 2 deletions NKL/Trace/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ partial def klr (e : Expr) : Tracer KLR.Expr :=

partial def integer (e : Expr) : Tracer Int := do
match <- klr e with
| .const c => return (<- c.toInt)
| _ => throw "expecting integer"
| .const c => return (<- c.toInt)
| _ => throw "expecting integer"

partial def expr' : Expr' -> Tracer Item
| .const c => return .term (<- const c)
Expand Down
32 changes: 30 additions & 2 deletions interop/test/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# This file exercises the Lean partial evaluator with
# a set of basic unit tests. Each function is parsed,
# handed to Lean, where it is checked and reduced to KLR.

import numpy as np
import nki
import pytest
Expand Down Expand Up @@ -65,6 +69,28 @@ def expr_bool_op(t):
1 or None # evals to 1
(False,) or 1 # evals to (False,)

def expr_cmp_op(t):
govereau marked this conversation as resolved.
Show resolved Hide resolved
assert 1 == 1
assert [] == []
assert not ([1,2] == [1])
assert not ([] < [])
assert [] < [1]
assert not ([1,2] < [1,2])
assert [1,1] < [1,2]
assert [1,2] < [1,2,3]
assert 1.2 < 2
assert 1 < 1.2
assert 1.2 < 1.3
assert 0.5 < True
assert not (0.5 < False)
assert "a" < "ab"
assert (1,2) is (1,2)
assert not ([1,2] is [1,2])
assert 1 in (1,2)
assert 1 in [3,2,1]
assert 1 not in (2,3,4)
assert 1 not in []

def assign(t):
x = y = 1
assert x == y
Expand All @@ -79,19 +105,21 @@ def assign(t):
assert b == 3
assert c == 4

# test each function in turn
@pytest.mark.parametrize("f", [
const_stmt,
expr_name,
expr_tuple,
expr_list,
expr_subscript,
expr_bool_op,
expr_cmp_op,
assign
])
def test_succeed(f):
t = np.ndarray(10)
F = Parser(f)
F(t)
F = Parser(f) # parse python
F(t) # specialize, and reduce to KLR

# Failing cases
# (These functions are expected to fail elaboration to KLR)
Expand Down