Skip to content

Commit

Permalink
fix: UDF, UDAF, UDWF with_alias(..) should wrap the inner function fully
Browse files Browse the repository at this point in the history
  • Loading branch information
Blizzara committed Aug 21, 2024
1 parent 37e54ee commit 51bc3ba
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 5 deletions.
60 changes: 59 additions & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// not implement the method, returns an error. Order insensitive and hard
/// requirement aggregators return `Ok(None)`.
fn with_beneficial_ordering(
self: Arc<Self>,
&self,
_beneficial_ordering: bool,
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
if self.order_sensitivity().is_beneficial() {
Expand Down Expand Up @@ -608,6 +608,60 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
&self.aliases
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
self.inner.state_fields(args)
}

fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
self.inner.groups_accumulator_supported(args)
}

fn create_groups_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
self.inner.create_groups_accumulator(args)
}

fn create_sliding_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
self.inner.accumulator(args)
}

fn with_beneficial_ordering(
&self,
beneficial_ordering: bool,
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
self.inner
.with_beneficial_ordering(beneficial_ordering)
.map(|udf| {
udf.map(|udf| {
Arc::new(AliasedAggregateUDFImpl {
inner: udf,
aliases: self.aliases.clone(),
}) as Arc<dyn AggregateUDFImpl>
})
})
}

fn order_sensitivity(&self) -> AggregateOrderSensitivity {
self.inner.order_sensitivity()
}

fn simplify(&self) -> Option<AggregateFunctionSimplification> {
self.inner.simplify()
}

fn reverse_expr(&self) -> ReversedUDAF {
self.inner.reverse_expr()
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
Expand All @@ -622,6 +676,10 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
self.aliases.hash(hasher);
hasher.finish()
}

fn is_descending(&self) -> Option<bool> {
self.inner.is_descending()
}
}

/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers
Expand Down
57 changes: 55 additions & 2 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,14 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
self.inner.name()
}

fn display_name(&self, args: &[Expr]) -> Result<String> {
self.inner.display_name(args)
}

fn schema_name(&self, args: &[Expr]) -> Result<String> {
self.inner.schema_name(args)
}

fn signature(&self) -> &Signature {
self.inner.signature()
}
Expand All @@ -632,12 +640,57 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
self.inner.return_type(arg_types)
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn return_type_from_exprs(
&self,
args: &[Expr],
schema: &dyn ExprSchema,
arg_types: &[DataType],
) -> Result<DataType> {
self.inner.return_type_from_exprs(args, schema, arg_types)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
self.inner.invoke(args)
}

fn aliases(&self) -> &[String] {
&self.aliases
fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
self.inner.invoke_no_args(number_rows)
}

fn simplify(
&self,
args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
self.inner.simplify(args, info)
}

fn short_circuits(&self) -> bool {
self.inner.short_circuits()
}

fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
self.inner.evaluate_bounds(input)
}

fn propagate_constraints(
&self,
interval: &Interval,
inputs: &[&Interval],
) -> Result<Option<Vec<Interval>>> {
self.inner.propagate_constraints(interval, inputs)
}

fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
self.inner.output_ordering(inputs)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
Expand Down
16 changes: 16 additions & 0 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
&self.aliases
}

fn simplify(&self) -> Option<WindowFunctionSimplification> {
self.inner.simplify()
}

fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<AliasedWindowUDFImpl>() {
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
Expand All @@ -442,6 +446,18 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
self.aliases.hash(hasher);
hasher.finish()
}

fn nullable(&self) -> bool {
self.inner.nullable()
}

fn sort_options(&self) -> Option<SortOptions> {
self.inner.sort_options()
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
}

/// Implementation of [`WindowUDFImpl`] that wraps the function style pointers
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl AggregateUDFImpl for FirstValue {
}

fn with_beneficial_ordering(
self: Arc<Self>,
&self,
beneficial_ordering: bool,
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
Ok(Some(Arc::new(
Expand Down Expand Up @@ -451,7 +451,7 @@ impl AggregateUDFImpl for LastValue {
}

fn with_beneficial_ordering(
self: Arc<Self>,
&self,
beneficial_ordering: bool,
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
Ok(Some(Arc::new(
Expand Down

0 comments on commit 51bc3ba

Please sign in to comment.