Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: implement strict UDFs with CASE #94797

Merged
merged 1 commit into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/bench/rttanalysis/testdata/benchmark_expectations
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ exp,benchmark
4,ORMQueries/django_column_introspection_1_table
4,ORMQueries/django_column_introspection_4_tables
4,ORMQueries/django_column_introspection_8_tables
5,ORMQueries/django_table_introspection_1_table
5,ORMQueries/django_table_introspection_8_tables
3,ORMQueries/django_table_introspection_1_table
3,ORMQueries/django_table_introspection_8_tables
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rafiss LGTY?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep!

0,ORMQueries/has_column_privilege_using_attnum
0,ORMQueries/has_column_privilege_using_column_name
0,ORMQueries/has_schema_privilege
Expand Down
73 changes: 69 additions & 4 deletions pkg/sql/logictest/testdata/logic_test/udf
Original file line number Diff line number Diff line change
Expand Up @@ -2517,6 +2517,58 @@ CREATE FUNCTION err(i imp) RETURNS INT LANGUAGE SQL AS $$
$$


subtest strict

statement ok
CREATE FUNCTION strict_fn(i INT, t TEXT, b BOOL) RETURNS INT STRICT LANGUAGE SQL AS $$
SELECT 1
$$

query I
SELECT strict_fn(1, 'foo', true)
----
1

# Same as above, but with non-constant arguments.
query I
WITH tmp(a, b, c) AS MATERIALIZED (VALUES (1, 'foo', true))
SELECT strict_fn(a, b, c) FROM tmp
----
1

query III
SELECT strict_fn(NULL, 'foo', true), strict_fn(1, NULL, true), strict_fn(1, 'foo', NULL)
----
NULL NULL NULL

query III
SELECT strict_fn(NULL, NULL, true), strict_fn(1, NULL, NULL), strict_fn(NULL, 'foo', NULL)
----
NULL NULL NULL

query I
SELECT strict_fn(NULL, NULL, NULL)
----
NULL

statement ok
CREATE FUNCTION strict_fn_imp(t TEXT, i imp) RETURNS INT RETURNS NULL ON NULL INPUT LANGUAGE SQL AS $$
SELECT 1
$$

# A tuple with all NULL elements is not considered "NULL INPUT" for a UDF, even
# though IS NULL returns true for it.
query IB
SELECT strict_fn_imp('foo', (NULL,NULL,NULL)), (NULL,NULL,NULL)::imp IS NULL
----
1 true

query I
SELECT strict_fn_imp('foo', NULL)
----
NULL


subtest return_type_assignment_casts

# Do not allow functions with return type mismatches that cannot be cast in an
Expand Down Expand Up @@ -2788,7 +2840,20 @@ SELECT oid, proname, pronamespace, proowner, prolang, proleakproof, proisstrict,
FROM pg_catalog.pg_proc WHERE proname IN ('f_93314', 'f_93314_alias', 'f_93314_comp', 'f_93314_comp_t')
ORDER BY oid;
----
100257 f_93314 105 1546506610 14 false false false v 0 100256 · {} NULL SELECT i, e FROM test.public.t_93314 ORDER BY i LIMIT 1;
100259 f_93314_alias 105 1546506610 14 false false false v 0 100258 · {} NULL SELECT i, e FROM test.public.t_93314_alias ORDER BY i LIMIT 1;
100263 f_93314_comp 105 1546506610 14 false false false v 0 100260 · {} NULL SELECT (1, 2);
100264 f_93314_comp_t 105 1546506610 14 false false false v 0 100262 · {} NULL SELECT a, c FROM test.public.t_93314_comp LIMIT 1;
100259 f_93314 105 1546506610 14 false false false v 0 100258 · {} NULL SELECT i, e FROM test.public.t_93314 ORDER BY i LIMIT 1;
100261 f_93314_alias 105 1546506610 14 false false false v 0 100260 · {} NULL SELECT i, e FROM test.public.t_93314_alias ORDER BY i LIMIT 1;
100265 f_93314_comp 105 1546506610 14 false false false v 0 100262 · {} NULL SELECT (1, 2);
100266 f_93314_comp_t 105 1546506610 14 false false false v 0 100264 · {} NULL SELECT a, c FROM test.public.t_93314_comp LIMIT 1;

# Regression test for #95240. Strict UDFs that are inlined should result in NULL
# when presented with NULL arguments.
statement ok
CREATE FUNCTION f95240(i INT) RETURNS INT STRICT LANGUAGE SQL AS 'SELECT 33';
CREATE TABLE t95240 (a INT);
INSERT INTO t95240 VALUES (1), (NULL)

query I
SELECT f95240(a) FROM t95240
----
33
NULL
3 changes: 0 additions & 3 deletions pkg/sql/opt/exec/execbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,6 @@ func (b *Builder) buildSubquery(
1, /* numStmts */
subquery.Typ,
false, /* enableStepping */
true, /* calledOnNullInput */
), nil
}

Expand Down Expand Up @@ -708,7 +707,6 @@ func (b *Builder) buildSubquery(
1, /* numStmts */
subquery.Typ,
false, /* enableStepping */
true, /* calledOnNullInput */
), nil
}

Expand Down Expand Up @@ -786,7 +784,6 @@ func (b *Builder) buildUDF(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Typ
len(udf.Body),
udf.Typ,
enableStepping,
udf.CalledOnNullInput,
), nil
}

Expand Down
45 changes: 45 additions & 0 deletions pkg/sql/opt/norm/testdata/rules/udf
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,48 @@ values
├── fd: ()-->(1)
└── tuple
└── const: 1

exec-ddl
CREATE FUNCTION strict_fn(i INT, t TEXT, b BOOl) RETURNS INT STRICT LANGUAGE SQL AS 'SELECT i'
----

# Strict UDFs should be folded to NULL in the presence of a constant NULL
# argument.
norm format=show-scalars
SELECT strict_fn(1, 'foo', NULL)
----
values
├── columns: strict_fn:1
├── cardinality: [1 - 1]
├── key: ()
├── fd: ()-->(1)
└── tuple
└── null

# The CASE expression used to check for NULL arguments in strict UDFs should be
# folded in the presence of all non-NULL arguments.
norm format=show-scalars
SELECT strict_fn(1, 'foo', true)
----
values
├── columns: strict_fn:5
├── cardinality: [1 - 1]
├── volatile
├── key: ()
├── fd: ()-->(5)
└── tuple
└── udf: strict_fn
├── params: i:1 t:2 b:3
├── args
│ ├── const: 1
│ ├── const: 'foo'
│ └── true
└── body
└── values
├── columns: i:4
├── outer: (1)
├── cardinality: [1 - 1]
├── key: ()
├── fd: ()-->(4)
└── tuple
└── variable: i:1
5 changes: 0 additions & 5 deletions pkg/sql/opt/ops/scalar.opt
Original file line number Diff line number Diff line change
Expand Up @@ -1261,11 +1261,6 @@ define UDFPrivate {
# immutable, or leakproof function will see a snapshot of the data as of the
# start of the statement calling the function.
Volatility Volatility

# CalledOnNullInput is true if the function should be called when any of its
# inputs are NULL. If false, the function will not be evaluated in the
# presence of NULL inputs, and will instead evaluate directly to NULL.
CalledOnNullInput bool
}

# KVOptions is a set of KVOptionItems that specify arbitrary keys and values
Expand Down
47 changes: 41 additions & 6 deletions pkg/sql/opt/optbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,14 +731,49 @@ func (b *Builder) buildUDF(
out = b.factory.ConstructUDF(
args,
&memo.UDFPrivate{
Name: def.Name,
Params: params,
Body: rels,
Typ: f.ResolvedType(),
Volatility: o.Volatility,
CalledOnNullInput: o.CalledOnNullInput,
Name: def.Name,
Params: params,
Body: rels,
Typ: f.ResolvedType(),
Volatility: o.Volatility,
},
)

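// If the UDF is strict, it should not be invoked when any of the arguments
// are NULL. To achieve this, we wrap the UDF in a CASE expression like:
//
// CASE WHEN arg1 IS NULL OR arg2 IS NULL OR ... THEN NULL ELSE udf() END
//
// NOTE(review): if args is empty, the loop below never assigns anyArgIsNull,
// so a nil condition is passed to ConstructWhen. Confirm whether a strict UDF
// with zero arguments can reach this branch; if so, it likely needs a
// len(args) > 0 guard so that the UDF is simply left unwrapped.
//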
if !o.CalledOnNullInput {
var anyArgIsNull opt.ScalarExpr
for i := range args {
// Note: We do NOT use a TupleIsNullExpr here if the argument is a
// tuple because a strict UDF will be called if an argument, T, is a
// tuple with all NULL elements, even though T IS NULL evaluates to
// true. For example:
//
// SELECT strict_fn(1, (NULL, NULL)) -- the UDF will be called
// SELECT (NULL, NULL) IS NULL -- returns true
//
argIsNull := b.factory.ConstructIs(args[i], memo.NullSingleton)
if anyArgIsNull == nil {
anyArgIsNull = argIsNull
continue
}
anyArgIsNull = b.factory.ConstructOr(argIsNull, anyArgIsNull)
}
out = b.factory.ConstructCase(
memo.TrueSingleton,
memo.ScalarListExpr{
b.factory.ConstructWhen(
anyArgIsNull,
b.factory.ConstructNull(f.ResolvedType()),
),
},
out,
)
}

return b.finishBuildScalar(f, out, inScope, outScope, outCol)
}

Expand Down
137 changes: 137 additions & 0 deletions pkg/sql/opt/optbuilder/testdata/udf
Original file line number Diff line number Diff line change
Expand Up @@ -1029,3 +1029,140 @@ project
└── projections
└── assignment-cast: CHAR [as=column3:3]
└── variable: s:2


# --------------------------------------------------
# UDFs that are STRICT/RETURNS NULL ON NULL INPUT.
# --------------------------------------------------

exec-ddl
CREATE FUNCTION strict_fn(i INT, t TEXT, b BOOl) RETURNS INT STRICT LANGUAGE SQL AS 'SELECT i'
----

build format=show-scalars
SELECT strict_fn(1, 'foo', false)
----
project
├── columns: strict_fn:5
├── values
│ └── tuple
└── projections
└── case [as=strict_fn:5]
├── true
├── when
│ ├── or
│ │ ├── is
│ │ │ ├── false
│ │ │ └── null
│ │ └── or
│ │ ├── is
│ │ │ ├── const: 'foo'
│ │ │ └── null
│ │ └── is
│ │ ├── const: 1
│ │ └── null
│ └── null
└── udf: strict_fn
├── params: i:1 t:2 b:3
├── args
│ ├── const: 1
│ ├── const: 'foo'
│ └── false
└── body
└── limit
├── columns: i:4
├── project
│ ├── columns: i:4
│ ├── values
│ │ └── tuple
│ └── projections
│ └── variable: i:1 [as=i:4]
└── const: 1

build format=show-scalars
SELECT strict_fn(a, b::TEXT, false) FROM abc WHERE strict_fn(a+1+2, b::TEXT, false) = 10
----
project
├── columns: strict_fn:14
├── select
│ ├── columns: a:1!null abc.b:2 c:3 crdb_internal_mvcc_timestamp:4 tableoid:5
│ ├── scan abc
│ │ └── columns: a:1!null abc.b:2 c:3 crdb_internal_mvcc_timestamp:4 tableoid:5
│ └── filters
│ └── eq
│ ├── case
│ │ ├── true
│ │ ├── when
│ │ │ ├── or
│ │ │ │ ├── is
│ │ │ │ │ ├── false
│ │ │ │ │ └── null
│ │ │ │ └── or
│ │ │ │ ├── is
│ │ │ │ │ ├── cast: STRING
│ │ │ │ │ │ └── variable: abc.b:2
│ │ │ │ │ └── null
│ │ │ │ └── is
│ │ │ │ ├── plus
│ │ │ │ │ ├── plus
│ │ │ │ │ │ ├── variable: a:1
│ │ │ │ │ │ └── const: 1
│ │ │ │ │ └── const: 2
│ │ │ │ └── null
│ │ │ └── null
│ │ └── udf: strict_fn
│ │ ├── params: i:6 t:7 b:8
│ │ ├── args
│ │ │ ├── plus
│ │ │ │ ├── plus
│ │ │ │ │ ├── variable: a:1
│ │ │ │ │ └── const: 1
│ │ │ │ └── const: 2
│ │ │ ├── cast: STRING
│ │ │ │ └── variable: abc.b:2
│ │ │ └── false
│ │ └── body
│ │ └── limit
│ │ ├── columns: i:9
│ │ ├── project
│ │ │ ├── columns: i:9
│ │ │ ├── values
│ │ │ │ └── tuple
│ │ │ └── projections
│ │ │ └── variable: i:6 [as=i:9]
│ │ └── const: 1
│ └── const: 10
└── projections
└── case [as=strict_fn:14]
├── true
├── when
│ ├── or
│ │ ├── is
│ │ │ ├── false
│ │ │ └── null
│ │ └── or
│ │ ├── is
│ │ │ ├── cast: STRING
│ │ │ │ └── variable: abc.b:2
│ │ │ └── null
│ │ └── is
│ │ ├── variable: a:1
│ │ └── null
│ └── null
└── udf: strict_fn
├── params: i:10 t:11 b:12
├── args
│ ├── variable: a:1
│ ├── cast: STRING
│ │ └── variable: abc.b:2
│ └── false
└── body
└── limit
├── columns: i:13
├── project
│ ├── columns: i:13
│ ├── values
│ │ └── tuple
│ └── projections
│ └── variable: i:10 [as=i:13]
└── const: 1
Loading