From 12c86df397799c1ebadd45ff1da7f40acad89b71 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 11 Feb 2015 17:51:41 -0800 Subject: [PATCH] Add tables() to SQLContext to return a DataFrame containing existing tables. --- python/pyspark/sql/context.py | 18 ++++ .../spark/sql/catalyst/analysis/Catalog.scala | 36 +++++++ .../org/apache/spark/sql/SQLContext.scala | 18 ++++ .../apache/spark/sql/ListTablesSuite.scala | 71 +++++++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 5 + .../spark/sql/hive/ListTablesSuite.scala | 99 +++++++++++++++++++ 6 files changed, 247 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index db4bcbece2c1b..9ab86a773aaff 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -621,6 +621,24 @@ def table(self, tableName): """ return DataFrame(self._ssql_ctx.table(tableName), self) + def tables(self, dbName=None): + """Returns a DataFrame containing names of table in the given database. + + If `dbName` is `None`, the database will be the current database. + + The returned DataFrame has two columns, tableName and isTemporary + (a column with BooleanType indicating if a table is a temporary one or not). + + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.tables() + >>> df2.first() + Row(tableName=u'table1', isTemporary=True) + """ + if dbName is None: + return DataFrame(self._ssql_ctx.tables(), self) + else: + return DataFrame(self._ssql_ctx.tables(dbName), self) + def cacheTable(self, tableName): """Caches the specified table in-memory.""" self._ssql_ctx.cacheTable(tableName) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index df8d03b86c533..d816235c36f51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -34,6 +34,12 @@ trait Catalog { tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan + /** + * Returns names and flags indicating if a table is temporary or not of all tables in the + * database identified by `databaseIdentifier`. + */ + def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] + def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit def unregisterTable(tableIdentifier: Seq[String]): Unit @@ -60,6 +66,10 @@ trait Catalog { protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = { (tableIdent.lift(tableIdent.size - 2), tableIdent.last) } + + protected def getDBName(databaseIdentifier: Seq[String]): Option[String] = { + databaseIdentifier.lift(databaseIdentifier.size - 1) + } } class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { @@ -101,6 +111,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { // properly qualified with this alias. alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) } + + override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = { + tables.map { + case (name, _) => (name, true) + }.toSeq + } } /** @@ -137,6 +153,22 @@ trait OverrideCatalog extends Catalog { withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias)) } + abstract override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = { + val dbName = getDBName(databaseIdentifier) + val temporaryTables = overrides.filter { + // If a temporary table does not have an associated database, we should return its name. + case ((None, _), _) => true + // If a temporary table does have an associated database, we should return it if the database + // matches the given database name. + case ((db: Some[String], _), _) if db == dbName => true + case _ => false + }.map { + case ((_, tableName), _) => (tableName, true) + }.toSeq + + temporaryTables ++ super.getTables(databaseIdentifier) + } + override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { @@ -172,6 +204,10 @@ object EmptyCatalog extends Catalog { throw new UnsupportedOperationException } + override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = { + throw new UnsupportedOperationException + } + def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index fd121ce05698c..9183c1f5f6227 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -734,6 +734,24 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): DataFrame = DataFrame(this, catalog.lookupRelation(Seq(tableName))) + /** + * Returns a [[DataFrame]] containing names of existing tables in the current database. + * The returned DataFrame has two columns, tableName and isTemporary (a column with BooleanType + * indicating if a table is a temporary one or not). + */ + def tables(databaseName: String): DataFrame = { + createDataFrame(catalog.getTables(Seq(databaseName))).toDataFrame("tableName", "isTemporary") + } + + /** + * Returns a [[DataFrame]] containing names of existing tables in the given database. + * The returned DataFrame has two columns, tableName and isTemporary (a column with BooleanType + * indicating if a table is a temporary one or not). + */ + def tables(): DataFrame = { + createDataFrame(catalog.getTables(Seq.empty[String])).toDataFrame("tableName", "isTemporary") + } + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala new file mode 100644 index 0000000000000..89906312478cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -0,0 +1,71 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} + +class ListTablesSuite extends QueryTest with BeforeAndAfterAll { + + import org.apache.spark.sql.test.TestSQLContext.implicits._ + + val df = + sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value") + + override def beforeAll(): Unit = { + (1 to 10).foreach(i => df.registerTempTable(s"table$i")) + } + + override def afterAll(): Unit = { + catalog.unregisterAllTables() + } + + test("get All Tables") { + checkAnswer(tables(), (1 to 10).map(i => Row(s"table$i", true))) + } + + test("getting All Tables with a database name has not impact on returned table names") { + checkAnswer(tables("DB"), (1 to 10).map(i => Row(s"table$i", true))) + } + + test("query the returned DataFrame of tables") { + val tableDF = tables() + val schema = StructType( + StructField("tableName", StringType, true) :: + StructField("isTemporary", BooleanType, false) :: Nil) + assert(schema === tableDF.schema) + + checkAnswer( + tableDF.select("tableName"), + (1 to 10).map(i => Row(s"table$i")) + ) + + tableDF.registerTempTable("tables") + checkAnswer( + sql("SELECT isTemporary, tableName from tables WHERE isTemporary"), + (1 to 10).map(i => Row(true, s"table$i")) + ) + checkAnswer( + tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + Row("tables", true)) + dropTempTable("tables") + } +} \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index c78369d12cf55..7be0eeb44d486 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -198,6 +198,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } + override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = { + val dbName = getDBName(databaseIdentifier).getOrElse(hive.sessionState.getCurrentDatabase) + client.getAllTables(dbName).map(tableName => (tableName, false)) + } + /** * Create table with specified database, table name, table description and schema * @param databaseName Database Name diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala new file mode 100644 index 0000000000000..236981f6a3d7e --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -0,0 +1,99 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} + +class ListTablesSuite extends QueryTest with BeforeAndAfterAll { + + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + val sqlContext = TestHive + val df = + sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value") + + override def beforeAll(): Unit = { + // The catalog in HiveContext is a case insensitive one. + (1 to 10).foreach(i => catalog.registerTable(Seq(s"Table$i"), df.logicalPlan)) + (1 to 10).foreach(i => catalog.registerTable(Seq("db1", s"db1TempTable$i"), df.logicalPlan)) + (1 to 10).foreach { + i => sql(s"CREATE TABLE hivetable$i (key int, value string)") + } + sql("CREATE DATABASE IF NOT EXISTS db1") + (1 to 10).foreach { + i => sql(s"CREATE TABLE db1.db1hivetable$i (key int, value string)") + } + } + + override def afterAll(): Unit = { + catalog.unregisterAllTables() + (1 to 10).foreach { + i => sql(s"DROP TABLE IF EXISTS hivetable$i") + } + (1 to 10).foreach { + i => sql(s"DROP TABLE IF EXISTS db1.db1hivetable$i") + } + sql("DROP DATABASE IF EXISTS db1") + } + + test("get All Tables of current database") { + // We are using default DB. + val expectedTables = + (1 to 10).map(i => Row(s"table$i", true)) ++ + (1 to 10).map(i => Row(s"hivetable$i", false)) + checkAnswer(tables(), expectedTables) + } + + test("getting All Tables with a database name has not impact on returned table names") { + val expectedTables = + // We are expecting to see Table1 to Table10 since there is no database associated with them. + (1 to 10).map(i => Row(s"table$i", true)) ++ + (1 to 10).map(i => Row(s"db1temptable$i", true)) ++ + (1 to 10).map(i => Row(s"db1hivetable$i", false)) + checkAnswer(tables("db1"), expectedTables) + } + + test("query the returned DataFrame of tables") { + val tableDF = tables() + val schema = StructType( + StructField("tableName", StringType, true) :: + StructField("isTemporary", BooleanType, false) :: Nil) + assert(schema === tableDF.schema) + + checkAnswer( + tableDF.filter("NOT isTemporary").select("tableName"), + (1 to 10).map(i => Row(s"hivetable$i")) + ) + + tableDF.registerTempTable("tables") + checkAnswer( + sql("SELECT isTemporary, tableName from tables WHERE isTemporary"), + (1 to 10).map(i => Row(true, s"table$i")) + ) + checkAnswer( + tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + Row("tables", true)) + dropTempTable("tables") + } +} \ No newline at end of file