diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 6a6dfbc5851b8..099d0c6a19e6a 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1622,6 +1622,48 @@ def func(iterator):
             df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
         return DataFrame(df, self)
 
+    def load(self, path=None, dataSourceName=None, schema=None, **options):
+        """Returns the dataset specified by the data source and a set of options
+        as a DataFrame. An optional schema can be applied as the schema of the returned
+        DataFrame. If dataSourceName is not provided, the default data source configured
+        by spark.sql.sources.default will be used.
+        """
+        if path is not None:
+            options["path"] = path
+        if dataSourceName is None:
+            dataSourceName = self._ssql_ctx.getConf("spark.sql.sources.default",
+                                                    "org.apache.spark.sql.parquet")
+        joptions = MapConverter().convert(options,
+                                          self._sc._gateway._gateway_client)
+        if schema is None:
+            df = self._ssql_ctx.load(dataSourceName, joptions)
+        else:
+            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+            df = self._ssql_ctx.load(dataSourceName, scala_datatype, joptions)
+        return DataFrame(df, self)
+
+    def createExternalTable(self, tableName, path=None, dataSourceName=None,
+                            schema=None, **options):
+        """Creates an external table based on the given data source and a set of options and
+        returns the corresponding DataFrame.
+        If dataSourceName is not provided, the default data source configured
+        by spark.sql.sources.default will be used.
+        """
+        if path is not None:
+            options["path"] = path
+        if dataSourceName is None:
+            dataSourceName = self._ssql_ctx.getConf("spark.sql.sources.default",
+                                                    "org.apache.spark.sql.parquet")
+        joptions = MapConverter().convert(options,
+                                          self._sc._gateway._gateway_client)
+        if schema is None:
+            df = self._ssql_ctx.createExternalTable(tableName, dataSourceName, joptions)
+        else:
+            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+            df = self._ssql_ctx.createExternalTable(tableName, dataSourceName, scala_datatype,
+                                                    joptions)
+        return DataFrame(df, self)
+
     def sql(self, sqlQuery):
         """Return a L{DataFrame} representing the result of the given query.
 
@@ -1889,9 +1931,57 @@ def insertInto(self, tableName, overwrite=False):
         """
         self._jdf.insertInto(tableName, overwrite)
 
-    def saveAsTable(self, tableName):
-        """Creates a new table with the contents of this DataFrame."""
-        self._jdf.saveAsTable(tableName)
+    def saveAsTable(self, tableName, dataSourceName=None, mode="append", **options):
+        """Creates a new table with the contents of this DataFrame based on the given data source
+        and a set of options. If a data source is not provided, the default data source configured
+        by spark.sql.sources.default will be used.
+ """ + if dataSourceName is None: + dataSourceName = self.sql_ctx._ssql_ctx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.ErrorIfExists + mode = mode.lower() + if mode == "append": + jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Append + elif mode == "overwrite": + jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Overwrite + elif mode == "ignore": + jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Ignore + elif mode == "error": + pass + else: + raise ValueError( + "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.") + joptions = MapConverter().convert(options, + self.sql_ctx._sc._gateway._gateway_client) + self._jdf.saveAsTable(tableName, dataSourceName, jmode, joptions) + + def save(self, path=None, dataSourceName=None, mode="append", **options): + """Saves the contents of the DataFrame to a data source based on the given data source, + the given save mode, and a set of options. If a data source is not provided, + the default data source configured by spark.sql.sources.default will be used. + """ + if path is not None: + options["path"] = path + if dataSourceName is None: + dataSourceName = self.sql_ctx._ssql_ctx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.ErrorIfExists + mode = mode.lower() + if mode == "append": + jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Append + elif mode == "overwrite": + jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Overwrite + elif mode == "ignore": + jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Ignore + elif mode == "error": + pass + else: + raise ValueError( + "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.") + joptions = MapConverter().convert(options, + self._sc._gateway._gateway_client) + self._jdf.save(dataSourceName, jmode, joptions) def schema(self): """Returns the schema of this DataFrame (represented by diff --git a/python/pyspark/sql_tests.py b/python/pyspark/sql_tests.py index d314f46e8d2d5..23b281c846e90 100644 --- a/python/pyspark/sql_tests.py +++ b/python/pyspark/sql_tests.py @@ -34,8 +34,8 @@ else: import unittest -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType +from pyspark.sql import SQLContext, HiveContext, IntegerType, Row, ArrayType, StructType,\ + StructField, UserDefinedType, DoubleType from pyspark.tests import ReusedPySparkTestCase @@ -285,6 +285,38 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0]) + def test_save_and_load(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.save(tmpPath, "org.apache.spark.sql.json", "error") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + from pyspark.sql import StructType, StructField, StringType + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + + df.save(tmpPath, "org.apache.spark.sql.json", "overwrite") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == 
+
+        df.save(dataSourceName="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
+                noUse="this option will not be used in save.")
+        actual = self.sqlCtx.load(dataSourceName="org.apache.spark.sql.json", path=tmpPath,
+                                  noUse="this option will not be used in load.")
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+        defaultDataSourceName = self.sqlCtx._ssql_ctx.getConf("spark.sql.sources.default",
+                                                              "org.apache.spark.sql.parquet")
+        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        actual = self.sqlCtx.load(path=tmpPath)
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+        shutil.rmtree(tmpPath)
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
@@ -294,6 +326,73 @@ def test_help_command(self):
         pydoc.render_doc(df.foo)
         pydoc.render_doc(df.take(1))
 
+class HiveContextSQLTests(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+        cls.sqlCtx = HiveContext(cls.sc)
+        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = cls.sc.parallelize(cls.testData)
+        cls.df = cls.sqlCtx.inferSchema(rdd)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+    def test_save_and_load_table(self):
+        df = self.df
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
+        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
+                                                 "org.apache.spark.sql.json")
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
+        from pyspark.sql import StructType, StructField, StringType
+        schema = StructType([StructField("value", StringType(), True)])
+        actual = self.sqlCtx.createExternalTable("externalJsonTable",
+                                                 dataSourceName="org.apache.spark.sql.json",
+                                                 schema=schema, path=tmpPath,
+                                                 noUse="this option will not be used")
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertTrue(
+            sorted(df.select("value").collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("DROP TABLE savedJsonTable")
+        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+        defaultDataSourceName = self.sqlCtx._ssql_ctx.getConf("spark.sql.sources.default",
+                                                              "org.apache.spark.sql.parquet")
+        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+        actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("DROP TABLE savedJsonTable")
+        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+        shutil.rmtree(tmpPath)
 
 if __name__ == "__main__":
     unittest.main()
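
For readers skimming the patch, a minimal usage sketch of the Python API added above follows. It is not part of the diff: the /tmp paths, the appName, and the toy JSON records are made up for illustration, and saveAsTable/createExternalTable additionally require a HiveContext, as the new HiveContextSQLTests exercise.

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, StructType, StructField, StringType

    sc = SparkContext(appName="save-load-example")        # appName is illustrative
    sqlCtx = SQLContext(sc)
    df = sqlCtx.jsonRDD(sc.parallelize(['{"key": 1, "value": "1"}',
                                        '{"key": 2, "value": "2"}']))

    # Save as JSON; mode="error" (ErrorIfExists) fails if the path already exists.
    df.save("/tmp/example_json", "org.apache.spark.sql.json", mode="error")

    # Load it back, optionally forcing a schema onto the returned DataFrame.
    loaded = sqlCtx.load("/tmp/example_json", "org.apache.spark.sql.json")
    schema = StructType([StructField("value", StringType(), True)])
    values = sqlCtx.load("/tmp/example_json", "org.apache.spark.sql.json", schema)

    # With no dataSourceName, spark.sql.sources.default (Parquet unless overridden) is used.
    df.save("/tmp/example_default")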