diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index a6a14df6a33ea..fb1c6182cf956 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -79,12 +79,12 @@ case class Filter(condition: Expression, child: SparkPlan)
 
   // Split out all the IsNotNulls from condition.
   private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
-    case IsNotNull(a) if child.output.exists(_.semanticEquals(a)) => true
+    case IsNotNull(a: NullIntolerant) if a.references.subsetOf(child.outputSet) => true
     case _ => false
   }
 
   // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
-  private val notNullAttributes = notNullPreds.flatMap(_.references)
+  private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
 
   // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
   // all the variables at the beginning to take advantage of short circuiting.
@@ -92,7 +92,7 @@ case class Filter(condition: Expression, child: SparkPlan)
 
   override def output: Seq[Attribute] = {
     child.output.map { a =>
-      if (a.nullable && notNullAttributes.exists(_.semanticEquals(a))) {
+      if (a.nullable && notNullAttributes.contains(a.exprId)) {
         a.withNullability(false)
       } else {
         a
@@ -179,7 +179,7 @@ case class Filter(condition: Expression, child: SparkPlan)
     // Reset the isNull to false for the not-null columns, then the followed operators could
     // generate better code (remove dead branches).
     val resultVars = input.zipWithIndex.map { case (ev, i) =>
-      if (notNullAttributes.exists(_.semanticEquals(child.output(i)))) {
+      if (notNullAttributes.contains(child.output(i).exprId)) {
        ev.isNull = "false"
       }
       ev
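
For illustration, here is a minimal, self-contained Scala sketch of the idea behind the second and third hunks. The names `Attr`, `ExprId`, and `IsNotNullPred` are hypothetical stand-ins, not Spark's actual classes: the attributes guarded by `IsNotNull` are resolved once, deduplicated, and reduced to their expression IDs, so later nullability checks become a plain `contains` lookup on IDs instead of repeated `exists(_.semanticEquals(...))` scans.

```scala
// Hypothetical stand-ins for Spark's Attribute / ExprId / IsNotNull machinery.
object NotNullByExprIdSketch {
  final case class ExprId(id: Long)
  final case class Attr(name: String, exprId: ExprId, nullable: Boolean) {
    def withNullability(n: Boolean): Attr = copy(nullable = n)
  }
  // Stand-in for an IsNotNull predicate split out of the filter condition.
  final case class IsNotNullPred(references: Seq[Attr])

  def main(args: Array[String]): Unit = {
    val a = Attr("a", ExprId(1), nullable = true)
    val b = Attr("b", ExprId(2), nullable = true)
    val childOutput = Seq(a, b)

    // Suppose the condition was: IsNotNull(a) AND IsNotNull(a) AND b > 0.
    val notNullPreds = Seq(IsNotNullPred(Seq(a)), IsNotNullPred(Seq(a)))

    // As in the patch: dedupe the references and keep only their IDs.
    val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)

    // Tighten nullability by ID membership, not structural comparison.
    val output = childOutput.map { attr =>
      if (attr.nullable && notNullAttributes.contains(attr.exprId)) {
        attr.withNullability(false)
      } else {
        attr
      }
    }
    output.foreach(println) // a becomes non-nullable, b stays nullable
  }
}
```

The first hunk is a separate change the sketch does not model: matching `IsNotNull(a: NullIntolerant)` with `a.references.subsetOf(child.outputSet)` widens the rule from bare attributes to any null-intolerant expression over the child's columns, which is sound because a null in any referenced input would make such an expression null and the row would be filtered out.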