Skip to content

Commit

Permalink
feat: implement comparison operators
Browse files Browse the repository at this point in the history
  • Loading branch information
govereau committed Jan 27, 2025
1 parent c65ccb5 commit e2ab7d9
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 6 deletions.
94 changes: 92 additions & 2 deletions NKL/Trace/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,99 @@ def unOp : String -> Term -> TraceM Term
| "Not", t => return .expr (.const $ .bool (<- t.isFalse)) .bool
| op, _ => throw s!"unimp {op}"

-- Comparison operators
/-
Comparison operators
These functions implement the Python comparison operators. For tensors, these
will be promoted to per-element operations, for everything else the should be
static. For example:
# comparison of two lists containing integer constants
assert a_input.shape == b_input.shape
# comparison of two integer constants
assert a_input.shape[0] <= nl.tile_size.pmax
We only need Eq (==) and Lt (<), other operators are implemted in terms of
these two.
-/

private def exprEq : Expr -> Expr -> TraceM Bool
| .var x, .var y => return x == y
| .const c₁, .const c₂ => return c₁ == c₂
| .tensor t₁, .tensor t₂ => return t₁.name == t₂.name
| _, _ => return false

private partial def termEq : Term -> Term -> TraceM Bool
| .object o₁, .object o₂ => return o₁.name == o₂.name
| .tuple l₁, .tuple l₂
| .list l₁, .list l₂ => do
if l₁.length != l₂.length then
return false
for (x,y) in (l₁.zip l₂) do
if not (<- termEq x y) then
return false
return true
| .expr e₁ _, .expr e₂ _ => exprEq e₁ e₂
| _, _ => return false

-- Python "is" operator is the same as == for all literals, except for lists.
private def termIsIdentical : Term -> Term -> TraceM Bool
| .list _, .list _ => return false
| l, r => termEq l r

-- Python: contains operator: 1 in [1,2,3]
private def termIn (x : Term) : Term -> TraceM Bool
| .tuple l | .list l => l.anyM (termEq x)
| _ => throw "invalid use of in"

private def constLt : Const -> Const -> TraceM Bool
-- comparison between same types
| .bool b₁, .bool b₂ => return !b₁ && b₂
| .int l, .int r => return l < r
| .float l, .float r => return l < r
| .string l, .string r => return l < r
-- float promotion
| .float f, .bool b => return f < if b then 1.0 else 0.0
| .bool b, .float f => return (if b then 1.0 else 0.0) < f
| .float f, .int i => return f < .ofInt i
| .int i, .float f => return .ofInt i < f
-- int promotion
| c, .int i => return (<- c.toInt) < i
| .int i, c => return i < (<- c.toInt)
-- errors
| .string _, _ | _, .string _ => throw "unsupported comparison"
| .none, _ | _, .none => throw "unsupported comparison"

private def termLt : Term -> Term -> TraceM Bool
| .tuple l₁, .tuple l₂
| .list l₁, .list l₂ => listLt l₁ l₂
| .expr (.const c₁) _, .expr (.const c₂) _ => constLt c₁ c₂
| _, _ => throw "unsupported comparison"
where
listLt : List Term -> List Term -> TraceM Bool
| [], [] => return false
| [], _ => return true
| _, [] => return false
| x :: xs, y :: ys => do
if <- termLt x y then return true
else return (<- termEq x y) && (<- listLt xs ys)

def cmpOp : String -> Term -> Term -> TraceM Bool
| s, l, r => throw s!"unsupported comparison operator {s} {repr l} {repr r}"
| "Eq", l, r => termEq l r
| "NotEq", l, r => return not (<- termEq l r)
| "Lt", l, r => termLt l r
| "LtE", l, r => return (<- termEq l r) || (<- termLt l r)
| "Gt", l, r => return not (<- termEq l r) && not (<- termLt l r)
| "GtE", l, r => return not (<- termLt l r)
| "Is", l, r => termIsIdentical l r
| "IsNot", l, r => return not (<- termIsIdentical l r)
| "In", l, r => termIn l r
| "NotIn", l, r => return not (<- termIn l r)
| op, _, _ => throw s!"unsupported comparison operator {op}"

-- Python comparison chains are short-circuting
-- e.g. x < y < z => x < y || y < z
def compare : Term -> List String -> List Term -> TraceM Term
| x, [op], [y] => return bool (<- cmpOp op x y)
| x, op::ops, y::ys => do
Expand All @@ -164,6 +252,8 @@ def compare : Term -> List String -> List Term -> TraceM Term
where
bool b := .expr (.const $ .bool b) .bool

-- Attributes

def Term.attr : Term -> String -> TraceM Term
| .object o, id => o.attr id
| .expr _ (.tensor d _), "dtype" => return (str d)
Expand Down
4 changes: 2 additions & 2 deletions NKL/Trace/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ partial def klr (e : Expr) : Tracer KLR.Expr :=

partial def integer (e : Expr) : Tracer Int := do
match <- klr e with
| .const c => return (<- c.toInt)
| _ => throw "expecting integer"
| .const c => return (<- c.toInt)
| _ => throw "expecting integer"

partial def expr' : Expr' -> Tracer Item
| .const c => return .term (<- const c)
Expand Down
32 changes: 30 additions & 2 deletions interop/test/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# This file exercises the Lean partial evaluator with
# a set of basic unit tests. Each function is parsed,
# handed to Lean, where it is checked and reduced to KLR.

import numpy as np
import nki
import pytest
Expand Down Expand Up @@ -65,6 +69,28 @@ def expr_bool_op(t):
1 or None # evals to 1
(False,) or 1 # evals to (False,)

def expr_cmp_op(t):
assert 1 == 1
assert [] == []
assert not ([1,2] == [1])
assert not ([] < [])
assert [] < [1]
assert not ([1,2] < [1,2])
assert [1,1] < [1,2]
assert [1,2] < [1,2,3]
assert 1.2 < 2
assert 1 < 1.2
assert 1.2 < 1.3
assert 0.5 < True
assert not (0.5 < False)
assert "a" < "ab"
assert (1,2) is (1,2)
assert not ([1,2] is [1,2])
assert 1 in (1,2)
assert 1 in [3,2,1]
assert 1 not in (2,3,4)
assert 1 not in []

def assign(t):
x = y = 1
assert x == y
Expand All @@ -79,19 +105,21 @@ def assign(t):
assert b == 3
assert c == 4

# test each function in turn
@pytest.mark.parametrize("f", [
const_stmt,
expr_name,
expr_tuple,
expr_list,
expr_subscript,
expr_bool_op,
expr_cmp_op,
assign
])
def test_succeed(f):
t = np.ndarray(10)
F = Parser(f)
F(t)
F = Parser(f) # parse python
F(t) # specialize, and reduce to KLR

# Failing cases
# (These functions are expected to fail elaboration to KLR)
Expand Down

0 comments on commit e2ab7d9

Please sign in to comment.