Skip to content

Commit

Permalink
[SPARK-50619][SQL] Refactor VariantGet.cast to pack the cast arguments
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

As the title. It refactors the code for simplification.

### Why are the changes needed?

The refactor will make it simpler for the shredded user to use `VariantGet.cast`.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49239 from chenhao-db/VariantCastArgs.

Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
chenhao-db authored and cloud-fan committed Dec 19, 2024
1 parent 7f6d554 commit 3a61eef
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,11 @@ case class Cast(
}
}

private lazy val castArgs = variant.VariantCastArgs(
evalMode != EvalMode.TRY,
timeZoneId,
zoneId)

def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)

// [[func]] assumes the input is no longer null because eval already does the null check.
Expand Down Expand Up @@ -1127,7 +1132,7 @@ case class Cast(
_ => throw QueryExecutionErrors.cannotCastFromNullTypeError(to)
} else if (from.isInstanceOf[VariantType]) {
buildCast[VariantVal](_, v => {
variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId, zoneId)
variant.VariantGet.cast(v, to, castArgs)
})
} else {
to match {
Expand Down Expand Up @@ -1225,12 +1230,10 @@ case class Cast(
case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
val tmp = ctx.freshVariable("tmp", classOf[Object])
val dataTypeArg = ctx.addReferenceObj("dataType", to)
val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val failOnError = evalMode != EvalMode.TRY
val castArgsArg = ctx.addReferenceObj("castArgs", castArgs)
val cls = classOf[variant.VariantGet].getName
code"""
Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg);
Object $tmp = $cls.cast($c, $dataTypeArg, $castArgsArg);
if ($tmp == null) {
$evNull = true;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,30 +278,28 @@ case class VariantGet(
override def nullable: Boolean = true
override def nullIntolerant: Boolean = true

private lazy val castArgs = VariantCastArgs(
failOnError,
timeZoneId,
zoneId)

protected override def nullSafeEval(input: Any, path: Any): Any = {
VariantGet.variantGet(
input.asInstanceOf[VariantVal],
parsedPath,
dataType,
failOnError,
timeZoneId,
zoneId)
VariantGet.variantGet(input.asInstanceOf[VariantVal], parsedPath, dataType, castArgs)
}

protected override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childCode = child.genCode(ctx)
val tmp = ctx.freshVariable("tmp", classOf[Object])
val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath)
val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val castArgsArg = ctx.addReferenceObj("castArgs", castArgs)
val code = code"""
${childCode.code}
boolean ${ev.isNull} = ${childCode.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
Object $tmp = org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg);
${childCode.value}, $parsedPathArg, $dataTypeArg, $castArgsArg);
if ($tmp == null) {
${ev.isNull} = true;
} else {
Expand All @@ -323,6 +321,12 @@ case class VariantGet(
override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
}

// Several parameters used by `VariantGet.cast`. Packed together to simplify parameter passing.
case class VariantCastArgs(
failOnError: Boolean,
zoneStr: Option[String],
zoneId: ZoneId)

case object VariantGet {
/**
* Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
Expand All @@ -347,9 +351,7 @@ case object VariantGet {
input: VariantVal,
parsedPath: Array[VariantPathParser.PathSegment],
dataType: DataType,
failOnError: Boolean,
zoneStr: Option[String],
zoneId: ZoneId): Any = {
castArgs: VariantCastArgs): Any = {
var v = new Variant(input.getValue, input.getMetadata)
for (path <- parsedPath) {
v = path match {
Expand All @@ -359,21 +361,16 @@ case object VariantGet {
}
if (v == null) return null
}
VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
VariantGet.cast(v, dataType, castArgs)
}

/**
* A simple wrapper of the `cast` function that takes `Variant` rather than `VariantVal`. The
* `Cast` expression uses it and makes the implementation simpler.
*/
def cast(
input: VariantVal,
dataType: DataType,
failOnError: Boolean,
zoneStr: Option[String],
zoneId: ZoneId): Any = {
def cast(input: VariantVal, dataType: DataType, castArgs: VariantCastArgs): Any = {
val v = new Variant(input.getValue, input.getMetadata)
VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
VariantGet.cast(v, dataType, castArgs)
}

/**
Expand All @@ -383,15 +380,10 @@ case object VariantGet {
* "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
* SQL NULL when it is false.
*/
def cast(
v: Variant,
dataType: DataType,
failOnError: Boolean,
zoneStr: Option[String],
zoneId: ZoneId): Any = {
def cast(v: Variant, dataType: DataType, castArgs: VariantCastArgs): Any = {
def invalidCast(): Any = {
if (failOnError) {
throw QueryExecutionErrors.invalidVariantCast(v.toJson(zoneId), dataType)
if (castArgs.failOnError) {
throw QueryExecutionErrors.invalidVariantCast(v.toJson(castArgs.zoneId), dataType)
} else {
null
}
Expand All @@ -411,7 +403,7 @@ case object VariantGet {
val input = variantType match {
case Type.OBJECT | Type.ARRAY =>
return if (dataType.isInstanceOf[StringType]) {
UTF8String.fromString(v.toJson(zoneId))
UTF8String.fromString(v.toJson(castArgs.zoneId))
} else {
invalidCast()
}
Expand Down Expand Up @@ -457,7 +449,7 @@ case object VariantGet {
}
case _ =>
if (Cast.canAnsiCast(input.dataType, dataType)) {
val result = Cast(input, dataType, zoneStr, EvalMode.TRY).eval()
val result = Cast(input, dataType, castArgs.zoneStr, EvalMode.TRY).eval()
if (result == null) invalidCast() else result
} else {
invalidCast()
Expand All @@ -468,7 +460,7 @@ case object VariantGet {
val size = v.arraySize()
val array = new Array[Any](size)
for (i <- 0 until size) {
array(i) = cast(v.getElementAtIndex(i), elementType, failOnError, zoneStr, zoneId)
array(i) = cast(v.getElementAtIndex(i), elementType, castArgs)
}
new GenericArrayData(array)
} else {
Expand All @@ -482,7 +474,7 @@ case object VariantGet {
for (i <- 0 until size) {
val field = v.getFieldAtIndex(i)
keyArray(i) = UTF8String.fromString(field.key)
valueArray(i) = cast(field.value, valueType, failOnError, zoneStr, zoneId)
valueArray(i) = cast(field.value, valueType, castArgs)
}
ArrayBasedMapData(keyArray, valueArray)
} else {
Expand All @@ -495,8 +487,7 @@ case object VariantGet {
val field = v.getFieldAtIndex(i)
st.getFieldIndex(field.key) match {
case Some(idx) =>
row.update(idx,
cast(field.value, fields(idx).dataType, failOnError, zoneStr, zoneId))
row.update(idx, cast(field.value, fields(idx).dataType, castArgs))
case _ =>
}
}
Expand Down

0 comments on commit 3a61eef

Please sign in to comment.