Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed May 19, 2021
1 parent 14ff3b4 commit cbdbc3d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
13 changes: 2 additions & 11 deletions integration_tests/src/main/python/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,8 @@ def test_make_array(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).selectExpr(
'array(a, b)',
'array(b, a, null, {}, {})'.format(s1, s2)))


@pytest.mark.parametrize('data_gen', all_basic_gens + [decimal_gen_default, decimal_gen_scale_precision], ids=idfn)
def test_make_array_of_array(data_gen):
(s1, s2) = gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).selectExpr(
'array(array(1, 2, 3), array(2), array(null), array())',
'array(array(), array(null), array(a, b))',
'array(array(b, a, null, {}, {}), array(a, b), array(), array(null))'.format(s1, s2)))
'array(b, a, null, {}, {})'.format(s1, s2),
'array(array(b, a, null, {}, {}), array(a), array(null))'.format(s1, s2)))


@pytest.mark.parametrize('data_gen', single_level_array_gens, ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2366,9 +2366,19 @@ object GpuOverrides {
TypeSig.numeric + TypeSig.NULL + TypeSig.STRING +
TypeSig.BOOLEAN + TypeSig.DATE + TypeSig.TIMESTAMP +
TypeSig.ARRAY.nested(TypeSig.numeric + TypeSig.NULL + TypeSig.STRING +
TypeSig.BOOLEAN + TypeSig.DATE + TypeSig.TIMESTAMP),
TypeSig.BOOLEAN + TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.ARRAY),
TypeSig.all))),
(in, conf, p, r) => new ExprMeta[CreateArray](in, conf, p, r) {

override def tagExprForGpu(): Unit = {
wrapped.dataType match {
case ArrayType(ArrayType(ArrayType(_, _), _), _) =>
willNotWorkOnGpu("Only support to create array or array of array, Found: " +
s"${wrapped.dataType}")
case _ =>
}
}

override def convertToGpu(): GpuExpression =
GpuCreateArray(childExprs.map(_.convertToGpu()), wrapped.useStringTypeWhenEmpty)
}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
package org.apache.spark.sql.rapids

import ai.rapids.cudf.{ColumnVector, DType}

import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuExpressionsUtils, GpuScalar}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuExpressionsUtils}
import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS
import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, NamedExpression}
Expand Down

0 comments on commit cbdbc3d

Please sign in to comment.