Skip to content

Commit

Permalink
[SPARK-6994][SQL] Add fieldIndex to schema (StructType)
Browse files Browse the repository at this point in the history
  • Loading branch information
vidmantas zemleris committed Apr 19, 2015
1 parent 327ebf0 commit 9564ebb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru

private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap

/**
* Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
Expand All @@ -1049,6 +1050,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
StructType(fields.filter(f => names.contains(f.name)))
}

/**
* Returns index of a given field
*/
def fieldIndex(name: String): Int = {
nameToIndex.getOrElse(name,
throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
}

protected[sql] def toAttributes: Seq[AttributeReference] =
map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ class DataTypeSuite extends FunSuite {
}
}

test("extract field index from a StructType") {
val struct = StructType(
StructField("a", LongType) ::
StructField("b", FloatType) :: Nil)

assert(struct.fieldIndex("a") === 0)
assert(struct.fieldIndex("b") === 1)

intercept[IllegalArgumentException] {
struct.fieldIndex("non_existent")
}
}

def checkDataTypeJsonRepr(dataType: DataType): Unit = {
test(s"JSON - $dataType") {
assert(DataType.fromJson(dataType.json) === dataType)
Expand Down

0 comments on commit 9564ebb

Please sign in to comment.