[SPARK-14142][SQL] Replace internal use of unionAll with union #11946

Closed · wants to merge 2 commits
4 changes: 2 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -360,7 +360,7 @@ def repartition(self, numPartitions, *cols):

>>> df.repartition(10).rdd.getNumPartitions()
10
- >>> data = df.unionAll(df).repartition("age")
+ >>> data = df.union(df).repartition("age")
>>> data.show()
+---+-----+
|age| name|
@@ -919,7 +919,7 @@ def union(self, other):
This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union
(that does deduplication of elements), use this function followed by a distinct.
"""
- return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
+ return DataFrame(self._jdf.union(other._jdf), self.sql_ctx)

@since(1.3)
def unionAll(self, other):
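
The docstring above is the crux of the rename: union keeps duplicate rows (SQL UNION ALL), and a trailing distinct() gives SQL-style UNION. A minimal PySpark sketch of that behavior (illustration only, not part of this diff; the local SparkContext setup is an assumption):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext("local", "union-demo")  # assumed local setup for the sketch
    sqlCtx = SQLContext(sc)

    df1 = sqlCtx.createDataFrame([(1,), (2,)], ["id"])
    df2 = sqlCtx.createDataFrame([(2,), (3,)], ["id"])

    print(df1.union(df2).count())             # 4: duplicate row (2,) kept, like UNION ALL
    print(df1.union(df2).distinct().count())  # 3: distinct() gives SQL-style UNION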
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests.py
@@ -599,7 +599,7 @@ def test_parquet_with_udt(self):
point = df1.head().point
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))

- def test_unionAll_with_udt(self):
+ def test_union_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row1 = (1.0, ExamplePoint(1.0, 2.0))
row2 = (2.0, ExamplePoint(3.0, 4.0))
@@ -608,7 +608,7 @@ def test_unionAll_with_udt(self):
df1 = self.sqlCtx.createDataFrame([row1], schema)
df2 = self.sqlCtx.createDataFrame([row2], schema)

- result = df1.unionAll(df2).orderBy("label").collect()
+ result = df1.union(df2).orderBy("label").collect()
self.assertEqual(
result,
[
@@ -280,7 +280,7 @@ package object dsl {

def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan)

- def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)
+ def union(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)

def generate(
generator: Generator,
@@ -250,7 +250,7 @@ class AnalysisErrorSuite extends AnalysisTest {

errorTest(
"union with unequal number of columns",
- testRelation.unionAll(testRelation2),
+ testRelation.union(testRelation2),
"union" :: "number of columns" :: testRelation2.output.length.toString ::
testRelation.output.length.toString :: Nil)

@@ -32,7 +32,7 @@ class AnalysisSuite extends AnalysisTest {
val plan = (1 to 100)
.map(_ => testRelation)
.fold[LogicalPlan](testRelation) { (a, b) =>
- a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
+ a.select(UnresolvedStar(None)).select('a).union(b.select(UnresolvedStar(None)))
}

assertAnalysisSuccess(plan)
@@ -60,8 +60,8 @@ class PruneFiltersSuite extends PlanTest {

val query =
tr1.where('a.attr > 10)
- .unionAll(tr2.where('d.attr > 10)
- .unionAll(tr3.where('g.attr > 10)))
+ .union(tr2.where('d.attr > 10)
+ .union(tr3.where('g.attr > 10)))
val queryWithUselessFilter = query.where('a.attr > 10)

val optimized = Optimize.execute(queryWithUselessFilter.analyze)
@@ -109,14 +109,14 @@ class ConstraintPropagationSuite extends SparkFunSuite {

assert(tr1
.where('a.attr > 10)
- .unionAll(tr2.where('e.attr > 10)
- .unionAll(tr3.where('i.attr > 10)))
+ .union(tr2.where('e.attr > 10)
+ .union(tr3.where('i.attr > 10)))
.analyze.constraints.isEmpty)

verifyConstraints(tr1
.where('a.attr > 10)
- .unionAll(tr2.where('d.attr > 10)
- .unionAll(tr3.where('g.attr > 10)))
+ .union(tr2.where('d.attr > 10)
+ .union(tr3.where('g.attr > 10)))
.analyze.constraints,
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a")))))
@@ -97,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
newBlocks
.map(_.toDF())
- .reduceOption(_ unionAll _)
+ .reduceOption(_ union _)
.getOrElse {
sys.error("No data selected!")
}
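
The reduceOption(_ union _) change above folds a sequence of DataFrames into a single UNION ALL. The same idea in PySpark terms, as a hedged sketch (union_frames is a hypothetical helper, not part of Spark):

    from functools import reduce  # builtin in Python 2; functools in Python 3

    from pyspark.sql import DataFrame

    def union_frames(frames):
        """Fold a list of same-schema DataFrames into one, mirroring reduceOption(_ union _)."""
        if not frames:
            raise ValueError("No data selected!")  # mirrors the sys.error fallback above
        return reduce(DataFrame.union, frames)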
@@ -363,7 +363,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
}

test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
- val table3x = testData.unionAll(testData).unionAll(testData)
+ val table3x = testData.union(testData).union(testData)
table3x.registerTempTable("testData3x")

sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable")
@@ -57,7 +57,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")

- assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+ assert(splits.reduce((a, b) => a.union(b)).sort("id").collect().toList ==
data.collect().toList, "incomplete or wrong split")

val s = splits.map(_.count())
@@ -94,8 +94,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}

test("union all") {
- val unionDF = testData.unionAll(testData).unionAll(testData)
- .unionAll(testData).unionAll(testData)
+ val unionDF = testData.union(testData).union(testData)
+ .union(testData).union(testData)

// Before optimizer, Union should be combined.
assert(unionDF.queryExecution.analyzed.collect {
@@ -107,7 +107,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
)
}

test("unionAll should union DataFrames with UDTs (SPARK-13410)") {
test("union should union DataFrames with UDTs (SPARK-13410)") {
val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0))))
val schema1 = StructType(Array(StructField("label", IntegerType, false),
StructField("point", new ExamplePointUDT(), false)))
@@ -118,7 +118,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df2 = sqlContext.createDataFrame(rowRDD2, schema2)

checkAnswer(
- df1.unionAll(df2).orderBy("label"),
+ df1.union(df2).orderBy("label"),
Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0)))
)
}
@@ -636,7 +636,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val jsonDF = sqlContext.read.json(jsonDir)
assert(parquetDF.inputFiles.nonEmpty)

- val unioned = jsonDF.unionAll(parquetDF).inputFiles.sorted
+ val unioned = jsonDF.union(parquetDF).inputFiles.sorted
val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).distinct.sorted
assert(unioned === allFiles)
}
@@ -1104,7 +1104,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
}

- val union = df1.unionAll(df2)
+ val union = df1.union(df2)
checkAnswer(
union.filter('i < rand(7) * 10),
expected(union)
@@ -184,7 +184,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}

test("big inner join, 4 matches per row") {
- val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
+ val bigData = testData.union(testData).union(testData).union(testData)
val bigDataX = bigData.as("x")
val bigDataY = bigData.as("y")

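
The test name reflects the construction: after three unions, every key in testData appears four times, so each row of the self-join finds four matches. A back-of-the-envelope sketch (hypothetical one-row table; sqlCtx as in the earlier sketch):

    df = sqlCtx.createDataFrame([(1, "a")], ["key", "value"])
    big = df.union(df).union(df).union(df)  # each key now appears 4 times
    print(big.alias("x").join(big.alias("y"), "key").count())  # 16 = 4 rows x 4 matches each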
@@ -250,8 +250,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("aggregation with codegen") {
// Prepare a table that we can group some rows.
sqlContext.table("testData")
.unionAll(sqlContext.table("testData"))
.unionAll(sqlContext.table("testData"))
.union(sqlContext.table("testData"))
.union(sqlContext.table("testData"))
.registerTempTable("testData3x")

try {
@@ -342,7 +342,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
sqlContext
.range(0, 1000)
.selectExpr("id % 500 as key", "id as value")
- .unionAll(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
+ .union(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
checkAnswer(
join,
expectedAnswer.collect())
@@ -44,7 +44,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
path.delete()

val base = sqlContext.range(100)
- val df = base.unionAll(base).select($"id", lit(1).as("data"))
+ val df = base.union(base).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)

checkAnswer(
@@ -122,7 +122,7 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA

// verify the append mode
df.write.mode(SaveMode.Append).json(path.toString)
- val df2 = df.unionAll(df)
+ val df2 = df.union(df)
df2.registerTempTable("jsonTable2")

checkLoad(df2, "jsonTable2")
@@ -82,7 +82,7 @@ public void saveTableAndQueryIt() {

@Test
public void testUDAF() {
- Dataset<Row> df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value"));
+ Dataset<Row> df = hc.range(0, 100).union(hc.range(0, 100)).select(col("id").as("value"));
UserDefinedAggregateFunction udaf = new MyDoubleSum();
UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf);
// Create Columns for the UDAF. For now, callUDF does not take an argument to specific if
@@ -186,7 +186,7 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton {
assertCached(table("refreshTable"))
checkAnswer(
table("refreshTable"),
table("src").unionAll(table("src")).collect())
table("src").union(table("src")).collect())

// Drop the table and create it again.
sql("DROP TABLE refreshTable")
@@ -198,7 +198,7 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton {
sql("REFRESH TABLE refreshTable")
checkAnswer(
table("refreshTable"),
table("src").unionAll(table("src")).collect())
table("src").union(table("src")).collect())
// It is not cached.
assert(!isCached("refreshTable"), "refreshTable should not be cached.")

@@ -113,11 +113,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
df.write.mode(SaveMode.Append).saveAsTable("t")
assert(sqlContext.tableNames().contains("t"))
checkAnswer(sqlContext.table("t"), df.unionAll(df))
checkAnswer(sqlContext.table("t"), df.union(df))
}

assert(sqlContext.tableNames(db).contains("t"))
checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df))
checkAnswer(sqlContext.table(s"$db.t"), df.union(df))

checkTablePath(db, "t")
}
@@ -128,7 +128,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
df.write.mode(SaveMode.Append).saveAsTable(s"$db.t")
assert(sqlContext.tableNames(db).contains("t"))
checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df))
checkAnswer(sqlContext.table(s"$db.t"), df.union(df))

checkTablePath(db, "t")
}
@@ -141,7 +141,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
assert(sqlContext.tableNames().contains("t"))

df.write.insertInto(s"$db.t")
checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df))
checkAnswer(sqlContext.table(s"$db.t"), df.union(df))
}
}
}
@@ -156,7 +156,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
assert(sqlContext.tableNames(db).contains("t"))

df.write.insertInto(s"$db.t")
checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df))
checkAnswer(sqlContext.table(s"$db.t"), df.union(df))
}
}

@@ -220,7 +220,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
hiveContext.refreshTable("t")
checkAnswer(
sqlContext.table("t"),
df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2))))
df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2))))
}
}
}
@@ -252,7 +252,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
hiveContext.refreshTable(s"$db.t")
checkAnswer(
sqlContext.table(s"$db.t"),
df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2))))
df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2))))
}
}
}
@@ -137,7 +137,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
fs.delete(commonSummaryPath, true)

df.write.mode(SaveMode.Append).parquet(path)
- checkAnswer(sqlContext.read.parquet(path), df.unionAll(df))
+ checkAnswer(sqlContext.read.parquet(path), df.union(df))

assert(fs.exists(summaryPath))
assert(fs.exists(commonSummaryPath))
@@ -60,7 +60,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
p2 <- Seq("foo", "bar")
} yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2")

- lazy val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2)
+ lazy val partitionedTestDF = partitionedTestDF1.union(partitionedTestDF2)

def checkQueries(df: DataFrame): Unit = {
// Selects everything
@@ -191,7 +191,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
sqlContext.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath).orderBy("a"),
- testDF.unionAll(testDF).orderBy("a").collect())
+ testDF.union(testDF).orderBy("a").collect())
}
}

@@ -268,7 +268,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
sqlContext.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath),
- partitionedTestDF.unionAll(partitionedTestDF).collect())
+ partitionedTestDF.union(partitionedTestDF).collect())
}
}

@@ -332,7 +332,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t")

withTable("t") {
checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect())
checkAnswer(sqlContext.table("t"), testDF.union(testDF).orderBy("a").collect())
}
}

@@ -415,7 +415,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.saveAsTable("t")

withTable("t") {
checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect())
checkAnswer(sqlContext.table("t"), partitionedTestDF.union(partitionedTestDF).collect())
}
}

@@ -625,7 +625,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.format(dataSourceName)
.option("dataSchema", df.schema.json)
.load(dir.getCanonicalPath),
- df.unionAll(df))
+ df.union(df))

// This will fail because AlwaysFailOutputCommitter is used when we do append.
intercept[Exception] {