Skip to content

Commit

Permalink
Implement inequality joins by translation to conditional joins (#17000)
Browse files Browse the repository at this point in the history
Implement inequality joins by using the newly-exposed conditional join from pylibcudf.

- Closes #16926

Authors:
  - Lawrence Mitchell (https://github.com/wence-)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #17000
  • Loading branch information
wence- authored Nov 8, 2024
1 parent e8935b9 commit 150d8d8
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 62 deletions.
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from cudf_polars.dsl.expressions.base import (
AggInfo,
Col,
ColRef,
Expr,
NamedExpr,
)
Expand All @@ -40,6 +41,7 @@
"LiteralColumn",
"Len",
"Col",
"ColRef",
"BooleanFunction",
"StringFunction",
"TemporalFunction",
Expand Down
35 changes: 34 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from cudf_polars.containers import Column, DataFrame

__all__ = ["Expr", "NamedExpr", "Col", "AggInfo", "ExecutionContext"]
__all__ = ["Expr", "NamedExpr", "Col", "AggInfo", "ExecutionContext", "ColRef"]


class AggInfo(NamedTuple):
Expand Down Expand Up @@ -249,3 +249,36 @@ def do_evaluate(
def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
return AggInfo([(self, plc.aggregation.collect_list(), self)])


class ColRef(Expr):
__slots__ = ("index", "table_ref")
_non_child = ("dtype", "index", "table_ref")
index: int
table_ref: plc.expressions.TableReference

def __init__(
self,
dtype: plc.DataType,
index: int,
table_ref: plc.expressions.TableReference,
column: Expr,
) -> None:
if not isinstance(column, Col):
raise TypeError("Column reference should only apply to columns")
self.dtype = dtype
self.index = index
self.table_ref = table_ref
self.children = (column,)

def do_evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
) -> Column:
"""Evaluate this expression given a dataframe for context."""
raise NotImplementedError(
"Only expect this node as part of an expression translated to libcudf AST."
)
70 changes: 69 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
import cudf_polars.dsl.expr as expr
from cudf_polars.containers import Column, DataFrame
from cudf_polars.dsl.nodebase import Node
from cudf_polars.dsl.to_ast import to_parquet_filter
from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
from cudf_polars.utils import dtypes
from cudf_polars.utils.versions import POLARS_VERSION_GT_112

if TYPE_CHECKING:
from collections.abc import Callable, Hashable, MutableMapping, Sequence
Expand All @@ -48,6 +49,7 @@
"Select",
"GroupBy",
"Join",
"ConditionalJoin",
"HStack",
"Distinct",
"Sort",
Expand Down Expand Up @@ -522,6 +524,12 @@ def do_evaluate(
) # pragma: no cover; post init trips first
if row_index is not None:
name, offset = row_index
if POLARS_VERSION_GT_112:
# If we sliced away some data from the start, that
# shifts the row index.
# But prior to 1.13, polars had this wrong, so we match behaviour
# https://github.com/pola-rs/polars/issues/19607
offset += skip_rows # pragma: no cover; polars 1.13 not yet released
dtype = schema[name]
step = plc.interop.from_arrow(
pa.scalar(1, type=plc.interop.to_arrow(dtype))
Expand Down Expand Up @@ -890,6 +898,66 @@ def do_evaluate(
return DataFrame(broadcasted).slice(options.slice)


class ConditionalJoin(IR):
"""A conditional inner join of two dataframes on a predicate."""

__slots__ = ("predicate", "options", "ast_predicate")
_non_child = ("schema", "predicate", "options")
predicate: expr.Expr
options: tuple

def __init__(
self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
) -> None:
self.schema = schema
self.predicate = predicate
self.options = options
self.children = (left, right)
self.ast_predicate = to_ast(predicate)
_, join_nulls, zlice, suffix, coalesce = self.options
# Preconditions from polars
assert not join_nulls
assert not coalesce
if self.ast_predicate is None:
raise NotImplementedError(
f"Conditional join with predicate {predicate}"
) # pragma: no cover; polars never delivers expressions we can't handle
self._non_child_args = (self.ast_predicate, zlice, suffix)

@classmethod
def do_evaluate(
cls,
predicate: plc.expressions.Expression,
zlice: tuple[int, int] | None,
suffix: str,
left: DataFrame,
right: DataFrame,
) -> DataFrame:
"""Evaluate and return a dataframe."""
lg, rg = plc.join.conditional_inner_join(left.table, right.table, predicate)
left = DataFrame.from_table(
plc.copying.gather(
left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
),
left.column_names,
)
right = DataFrame.from_table(
plc.copying.gather(
right.table, rg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
),
right.column_names,
)
right = right.rename_columns(
{
name: f"{name}{suffix}"
for name in right.column_names
if name in left.column_names_set
}
)
result = left.with_columns(right.columns)
return result.slice(zlice)


class Join(IR):
"""A join of two dataframes."""

Expand Down
79 changes: 67 additions & 12 deletions python/cudf_polars/cudf_polars/dsl/to_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
from pylibcudf import expressions as plc_expr

from cudf_polars.dsl import expr
from cudf_polars.dsl.traversal import CachingVisitor
from cudf_polars.dsl.traversal import CachingVisitor, reuse_if_unchanged
from cudf_polars.typing import GenericTransformer

if TYPE_CHECKING:
from collections.abc import Mapping

from cudf_polars.typing import ExprTransformer

# Can't merge these op-mapping dictionaries because scoped enum values
# are exposed by cython with equality/hash based one their underlying
# representation type. So in a dict they are just treated as integers.
Expand Down Expand Up @@ -128,7 +130,14 @@ def _to_ast(node: expr.Expr, self: Transformer) -> plc_expr.Expression:
def _(node: expr.Col, self: Transformer) -> plc_expr.Expression:
if self.state["for_parquet"]:
return plc_expr.ColumnNameReference(node.name)
return plc_expr.ColumnReference(self.state["name_to_index"][node.name])
raise TypeError("Should always be wrapped in a ColRef node before translation")


@_to_ast.register
def _(node: expr.ColRef, self: Transformer) -> plc_expr.Expression:
if self.state["for_parquet"]:
raise TypeError("Not expecting ColRef node in parquet filter")
return plc_expr.ColumnReference(node.index, node.table_ref)


@_to_ast.register
Expand Down Expand Up @@ -238,28 +247,74 @@ def to_parquet_filter(node: expr.Expr) -> plc_expr.Expression | None:
return None


def to_ast(
node: expr.Expr, *, name_to_index: Mapping[str, int]
) -> plc_expr.Expression | None:
def to_ast(node: expr.Expr) -> plc_expr.Expression | None:
"""
Convert an expression to libcudf AST nodes suitable for compute_column.
Parameters
----------
node
Expression to convert.
name_to_index
Mapping from column names to their index in the table that
will be used for expression evaluation.
Notes
-----
`Col` nodes must always be wrapped in `TableRef` nodes when
converting to an ast expression so that their table reference and
index are provided.
Returns
-------
pylibcudf Expressoin if conversion is possible, otherwise None.
pylibcudf Expression if conversion is possible, otherwise None.
"""
mapper = CachingVisitor(
_to_ast, state={"for_parquet": False, "name_to_index": name_to_index}
)
mapper = CachingVisitor(_to_ast, state={"for_parquet": False})
try:
return mapper(node)
except (KeyError, NotImplementedError):
return None


def _insert_colrefs(node: expr.Expr, rec: ExprTransformer) -> expr.Expr:
if isinstance(node, expr.Col):
return expr.ColRef(
node.dtype,
rec.state["name_to_index"][node.name],
rec.state["table_ref"],
node,
)
return reuse_if_unchanged(node, rec)


def insert_colrefs(
node: expr.Expr,
*,
table_ref: plc.expressions.TableReference,
name_to_index: Mapping[str, int],
) -> expr.Expr:
"""
Insert column references into an expression before conversion to libcudf AST.
Parameters
----------
node
Expression to insert references into.
table_ref
pylibcudf `TableReference` indicating whether column
references are coming from the left or right table.
name_to_index:
Mapping from column names to column indices in the table
eventually used for evaluation.
Notes
-----
All column references are wrapped in the same, singular, table
reference, so this function relies on the expression only
containing column references from a single table.
Returns
-------
New expression with column references inserted.
"""
mapper = CachingVisitor(
_insert_colrefs, state={"table_ref": table_ref, "name_to_index": name_to_index}
)
return mapper(node)
68 changes: 25 additions & 43 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json
from contextlib import AbstractContextManager, nullcontext
from functools import singledispatch
from typing import TYPE_CHECKING, Any
from typing import Any

import pyarrow as pa
from typing_extensions import assert_never
Expand All @@ -21,13 +21,10 @@
import pylibcudf as plc

from cudf_polars.dsl import expr, ir
from cudf_polars.dsl.traversal import make_recursive, reuse_if_unchanged
from cudf_polars.dsl.to_ast import insert_colrefs
from cudf_polars.typing import NodeTraverser
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
from cudf_polars.typing import ExprTransformer

__all__ = ["translate_ir", "translate_named_expr"]


Expand Down Expand Up @@ -204,55 +201,40 @@ def _(
raise NotImplementedError(
f"Unsupported join type {how}"
) # pragma: no cover; asof joins not yet exposed
# No exposure of mixed/conditional joins in pylibcudf yet, so in
# the first instance, implement by doing a cross join followed by
# a filter.
_, join_nulls, zlice, suffix, coalesce = node.options
cross = ir.Join(
schema,
[],
[],
("cross", join_nulls, None, suffix, coalesce),
inp_left,
inp_right,
)
dtype = plc.DataType(plc.TypeId.BOOL8)
if op2 is None:
ops = [op1]
else:
ops = [op1, op2]
suffix = cross.options[3]

# Column references in the right table refer to the post-join
# names, so with suffixes.
def _rename(e: expr.Expr, rec: ExprTransformer) -> expr.Expr:
if isinstance(e, expr.Col) and e.name in inp_left.schema:
return type(e)(e.dtype, f"{e.name}{suffix}")
return reuse_if_unchanged(e, rec)

mapper = make_recursive(_rename)
right_on = [
expr.NamedExpr(
f"{old.name}{suffix}" if old.name in inp_left.schema else old.name, new
)
for new, old in zip(
(mapper(e.value) for e in right_on), right_on, strict=True
)
]
mask = functools.reduce(

dtype = plc.DataType(plc.TypeId.BOOL8)
predicate = functools.reduce(
functools.partial(
expr.BinOp, dtype, plc.binaryop.BinaryOperator.LOGICAL_AND
),
(
expr.BinOp(dtype, expr.BinOp._MAPPING[op], left.value, right.value)
expr.BinOp(
dtype,
expr.BinOp._MAPPING[op],
insert_colrefs(
left.value,
table_ref=plc.expressions.TableReference.LEFT,
name_to_index={
name: i for i, name in enumerate(inp_left.schema)
},
),
insert_colrefs(
right.value,
table_ref=plc.expressions.TableReference.RIGHT,
name_to_index={
name: i for i, name in enumerate(inp_right.schema)
},
),
)
for op, left, right in zip(ops, left_on, right_on, strict=True)
),
)
filtered = ir.Filter(schema, expr.NamedExpr("mask", mask), cross)
if zlice is not None:
offset, length = zlice
return ir.Slice(schema, offset, length, filtered)
return filtered

return ir.ConditionalJoin(schema, predicate, node.options, inp_left, inp_right)


@_translate_ir.register
Expand Down
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/utils/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

POLARS_VERSION_LT_111 = POLARS_VERSION < parse("1.11")
POLARS_VERSION_LT_112 = POLARS_VERSION < parse("1.12")
POLARS_VERSION_GT_112 = POLARS_VERSION > parse("1.12")
POLARS_VERSION_LT_113 = POLARS_VERSION < parse("1.13")


def _ensure_polars_version():
Expand Down
Loading

0 comments on commit 150d8d8

Please sign in to comment.