Commit 12c86df
Add tables() to SQLContext to return a DataFrame containing existing tables.
yhuai committed Feb 12, 2015
1 parent 44b2311 commit 12c86df
Showing 6 changed files with 247 additions and 0 deletions.
18 changes: 18 additions & 0 deletions python/pyspark/sql/context.py
@@ -621,6 +621,24 @@ def table(self, tableName):
        """
        return DataFrame(self._ssql_ctx.table(tableName), self)

    def tables(self, dbName=None):
        """Returns a DataFrame containing names of tables in the given database.

        If `dbName` is `None`, the current database will be used.
        The returned DataFrame has two columns: tableName and isTemporary
        (a BooleanType column indicating whether the table is temporary).

        >>> sqlCtx.registerRDDAsTable(df, "table1")
        >>> df2 = sqlCtx.tables()
        >>> df2.first()
        Row(tableName=u'table1', isTemporary=True)
        """
        if dbName is None:
            return DataFrame(self._ssql_ctx.tables(), self)
        else:
            return DataFrame(self._ssql_ctx.tables(dbName), self)

    def cacheTable(self, tableName):
        """Caches the specified table in-memory."""
        self._ssql_ctx.cacheTable(tableName)
36 changes: 36 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -34,6 +34,12 @@ trait Catalog {
      tableIdentifier: Seq[String],
      alias: Option[String] = None): LogicalPlan

  /**
   * Returns the names of all tables in the database identified by `databaseIdentifier`,
   * along with a flag indicating whether each table is temporary.
   */
  def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)]

  def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit

  def unregisterTable(tableIdentifier: Seq[String]): Unit
@@ -60,6 +66,10 @@ trait Catalog {
  protected def getDBTable(tableIdent: Seq[String]): (Option[String], String) = {
    (tableIdent.lift(tableIdent.size - 2), tableIdent.last)
  }

  protected def getDBName(databaseIdentifier: Seq[String]): Option[String] = {
    databaseIdentifier.lift(databaseIdentifier.size - 1)
  }
}
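A note on the `lift` idiom above: `Seq.lift(i)` returns an `Option`, yielding `None` when the index is out of range, which is how `getDBTable` and `getDBName` tolerate identifiers both with and without a database qualifier. A minimal sketch of the behavior (hypothetical REPL values, not part of the commit):

// How lift-based extraction behaves on qualified vs. unqualified identifiers.
val qualified = Seq("db1", "table1")
val unqualified = Seq("table1")

// getDBTable semantics: (database, table)
(qualified.lift(qualified.size - 2), qualified.last)       // (Some("db1"), "table1")
(unqualified.lift(unqualified.size - 2), unqualified.last) // (None, "table1")

// getDBName semantics: the last element, if any
Seq("db1").lift(0)         // Some("db1")
Seq.empty[String].lift(-1) // None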

class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
@@ -101,6 +111,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
    // properly qualified with this alias.
    alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
  }

  override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
    // Every table registered in a SimpleCatalog is temporary, so the flag is always true.
    tables.map {
      case (name, _) => (name, true)
    }.toSeq
  }
}

@@ -137,6 +153,22 @@ trait OverrideCatalog extends Catalog {
    withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
  }

  abstract override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
    val dbName = getDBName(databaseIdentifier)
    val temporaryTables = overrides.filter {
      // If a temporary table does not have an associated database, always return its name.
      case ((None, _), _) => true
      // If a temporary table does have an associated database, return it only when the
      // database matches the given database name.
      case ((db: Some[String], _), _) if db == dbName => true
      case _ => false
    }.map {
      case ((_, tableName), _) => (tableName, true)
    }.toSeq

    temporaryTables ++ super.getTables(databaseIdentifier)
  }
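To make the key matching above concrete, here is a small standalone sketch (hypothetical table names, not part of the commit) of how the `(Option[String], String)` keys in `overrides` are filtered:

// Filtering (Option[databaseName], tableName) keys the way getTables does.
val registered = Seq(
  (None: Option[String], "temp1"), // unqualified temp table: always listed
  (Some("db1"), "temp2"),          // listed only when db1 is the requested database
  (Some("db2"), "temp3"))
val dbName: Option[String] = Some("db1")

val listed = registered.collect {
  case (None, tableName) => (tableName, true)
  case (db @ Some(_), tableName) if db == dbName => (tableName, true)
}
// listed == Seq(("temp1", true), ("temp2", true))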

  override def registerTable(
      tableIdentifier: Seq[String],
      plan: LogicalPlan): Unit = {
@@ -172,6 +204,10 @@ object EmptyCatalog extends Catalog {
    throw new UnsupportedOperationException
  }

  override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
    throw new UnsupportedOperationException
  }

  def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
    throw new UnsupportedOperationException
  }
18 changes: 18 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -734,6 +734,24 @@ class SQLContext(@transient val sparkContext: SparkContext)
  def table(tableName: String): DataFrame =
    DataFrame(this, catalog.lookupRelation(Seq(tableName)))
  /**
   * Returns a [[DataFrame]] containing names of existing tables in the given database.
   * The returned DataFrame has two columns, tableName and isTemporary (a column with BooleanType
   * indicating if a table is a temporary one or not).
   */
  def tables(databaseName: String): DataFrame = {
    createDataFrame(catalog.getTables(Seq(databaseName))).toDataFrame("tableName", "isTemporary")
  }

  /**
   * Returns a [[DataFrame]] containing names of existing tables in the current database.
   * The returned DataFrame has two columns, tableName and isTemporary (a column with BooleanType
   * indicating if a table is a temporary one or not).
   */
  def tables(): DataFrame = {
    createDataFrame(catalog.getTables(Seq.empty[String])).toDataFrame("tableName", "isTemporary")
  }
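A quick usage sketch of the two new methods (hypothetical session, not part of the commit; assumes an existing `sqlContext` and `sparkContext`, and uses the pre-1.3 `toDataFrame` API seen in the tests below):

// Register an illustrative temporary table, then list tables as a queryable DataFrame.
import sqlContext.implicits._
val df = sparkContext.parallelize((1 to 3).map(i => (i, s"str$i"))).toDataFrame("key", "value")
df.registerTempTable("example")

sqlContext.tables().filter("isTemporary").select("tableName").show() // lists "example"
sqlContext.tables("db1").show() // same unqualified temp tables, plus anything under db1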

  protected[sql] class SparkPlanner extends SparkStrategies {
    val sparkContext: SparkContext = self.sparkContext

71 changes: 71 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -0,0 +1,71 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}

class ListTablesSuite extends QueryTest with BeforeAndAfterAll {

  import org.apache.spark.sql.test.TestSQLContext.implicits._

  val df =
    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDataFrame("key", "value")

  override def beforeAll(): Unit = {
    (1 to 10).foreach(i => df.registerTempTable(s"table$i"))
  }

  override def afterAll(): Unit = {
    catalog.unregisterAllTables()
  }

  test("get all tables") {
    checkAnswer(tables(), (1 to 10).map(i => Row(s"table$i", true)))
  }

test("getting All Tables with a database name has not impact on returned table names") {
checkAnswer(tables("DB"), (1 to 10).map(i => Row(s"table$i", true)))
}

test("query the returned DataFrame of tables") {
val tableDF = tables()
val schema = StructType(
StructField("tableName", StringType, true) ::
StructField("isTemporary", BooleanType, false) :: Nil)
assert(schema === tableDF.schema)

checkAnswer(
tableDF.select("tableName"),
(1 to 10).map(i => Row(s"table$i"))
)

tableDF.registerTempTable("tables")
checkAnswer(
sql("SELECT isTemporary, tableName from tables WHERE isTemporary"),
(1 to 10).map(i => Row(true, s"table$i"))
)
checkAnswer(
tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
Row("tables", true))
dropTempTable("tables")
}
}
5 changes: 5 additions & 0 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -198,6 +198,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
    }
  }

  override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
    // Fall back to the session's current database when no database name is given.
    val dbName = getDBName(databaseIdentifier).getOrElse(hive.sessionState.getCurrentDatabase)
    // Metastore tables are never temporary, so the flag is always false.
    client.getAllTables(dbName).map(tableName => (tableName, false))
  }
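Because HiveContext's catalog mixes OverrideCatalog into this class, tables() concatenates both sources: registered temporary tables come back flagged true, followed by metastore tables in the resolved database flagged false. A hypothetical sketch (names are illustrative, not part of the commit):

// hiveContext is an existing HiveContext; df as in the suites in this commit.
hiveContext.sql("CREATE TABLE hivetable (key INT, value STRING)")
df.registerTempTable("temptable")

hiveContext.tables().show()
// temptable   true    <- from the OverrideCatalog overrides
// hivetable   false   <- from the Hive metastore's current database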

  /**
   * Create table with specified database, table name, table description and schema
   * @param databaseName Database Name
99 changes: 99 additions & 0 deletions sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
@@ -0,0 +1,99 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hive

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}

class ListTablesSuite extends QueryTest with BeforeAndAfterAll {

  import org.apache.spark.sql.hive.test.TestHive.implicits._

  val sqlContext = TestHive
  val df =
    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDataFrame("key", "value")

  override def beforeAll(): Unit = {
    // The catalog in HiveContext is a case insensitive one.
    (1 to 10).foreach(i => catalog.registerTable(Seq(s"Table$i"), df.logicalPlan))
    (1 to 10).foreach(i => catalog.registerTable(Seq("db1", s"db1TempTable$i"), df.logicalPlan))
    (1 to 10).foreach {
      i => sql(s"CREATE TABLE hivetable$i (key int, value string)")
    }
    sql("CREATE DATABASE IF NOT EXISTS db1")
    (1 to 10).foreach {
      i => sql(s"CREATE TABLE db1.db1hivetable$i (key int, value string)")
    }
  }

  override def afterAll(): Unit = {
    catalog.unregisterAllTables()
    (1 to 10).foreach {
      i => sql(s"DROP TABLE IF EXISTS hivetable$i")
    }
    (1 to 10).foreach {
      i => sql(s"DROP TABLE IF EXISTS db1.db1hivetable$i")
    }
    sql("DROP DATABASE IF EXISTS db1")
  }

test("get All Tables of current database") {
// We are using default DB.
val expectedTables =
(1 to 10).map(i => Row(s"table$i", true)) ++
(1 to 10).map(i => Row(s"hivetable$i", false))
checkAnswer(tables(), expectedTables)
}

test("getting All Tables with a database name has not impact on returned table names") {
val expectedTables =
// We are expecting to see Table1 to Table10 since there is no database associated with them.
(1 to 10).map(i => Row(s"table$i", true)) ++
(1 to 10).map(i => Row(s"db1temptable$i", true)) ++
(1 to 10).map(i => Row(s"db1hivetable$i", false))
checkAnswer(tables("db1"), expectedTables)
}

test("query the returned DataFrame of tables") {
val tableDF = tables()
val schema = StructType(
StructField("tableName", StringType, true) ::
StructField("isTemporary", BooleanType, false) :: Nil)
assert(schema === tableDF.schema)

checkAnswer(
tableDF.filter("NOT isTemporary").select("tableName"),
(1 to 10).map(i => Row(s"hivetable$i"))
)

tableDF.registerTempTable("tables")
checkAnswer(
sql("SELECT isTemporary, tableName from tables WHERE isTemporary"),
(1 to 10).map(i => Row(true, s"table$i"))
)
checkAnswer(
tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
Row("tables", true))
dropTempTable("tables")
}
}
