Skip to content

Commit

Permalink
Better toString, factories for AttributeSet.
Browse files Browse the repository at this point in the history
  • Loading branch information
marmbrus committed Oct 5, 2014
1 parent cf1d32e commit fbeab54
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,26 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.Star

protected class AttributeEquals(val a: Attribute) {
override def hashCode() = a.exprId.hashCode()
override def equals(other: Any) = other match {
case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
case otherAttribute => false
override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
case (a1, a2) => a1 == a2
}
}

object AttributeSet {
/** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */
def apply(baseSet: Seq[Attribute]) = {
new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
}
def apply(a: Attribute) =
new AttributeSet(Set(new AttributeEquals(a)))

/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
def apply(baseSet: Seq[Expression]) =
new AttributeSet(
baseSet
.flatMap(_.references)
.map(new AttributeEquals(_)).toSet)
}

/**
Expand Down Expand Up @@ -103,4 +110,6 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
// We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
// sorts of things in its closure.
override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq

override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}"
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ abstract class NamedExpression extends Expression {
abstract class Attribute extends NamedExpression {
self: Product =>

override def references = AttributeSet(this)

def withNullability(newNullability: Boolean): Attribute
def withQualifiers(newQualifiers: Seq[String]): Attribute
def withName(newName: String): Attribute
Expand Down Expand Up @@ -116,8 +118,6 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
(val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
extends Attribute with trees.LeafNode[Expression] {

override def references = AttributeSet(this :: Nil)

override def equals(other: Any) = other match {
case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
case _ => false
Expand Down

0 comments on commit fbeab54

Please sign in to comment.