diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaUDF.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaUDF.scala index fb5647b82fb..4c26595ead8 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaUDF.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaUDF.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.delta import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction} import org.apache.spark.sql.functions.udf @@ -81,8 +82,8 @@ object DeltaUDF { orElse: => UserDefinedFunction): UserDefinedFunction = { if (SparkSession.active.sessionState.conf .getConf(DeltaSQLConf.INTERNAL_UDF_OPTIMIZATION_ENABLED)) { - val inputEncoders = template.inputEncoders.map(_.map(_.copy())) - val outputEncoder = template.outputEncoder.map(_.copy()) + val inputEncoders = template.inputEncoders.map(_.map(e => encoderFor(e))) + val outputEncoder = template.outputEncoder.map(e => encoderFor(e)) template.copy(f = f, inputEncoders = inputEncoders, outputEncoder = outputEncoder) } else { orElse