From 72f35b1c178edbf304a3563438ad11e3acd271bf Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 16 Apr 2015 18:39:27 +0800 Subject: [PATCH 1/2] DataFrame.withColumn can replace original column with identical column name. --- .../scala/org/apache/spark/sql/DataFrame.scala | 14 +++++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 4 ++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3235f85d5bbd2..6fa52a37464ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -747,7 +747,19 @@ class DataFrame private[sql]( * Returns a new [[DataFrame]] by adding a column. * @group dfops */ - def withColumn(colName: String, col: Column): DataFrame = select(Column("*"), col.as(colName)) + def withColumn(colName: String, col: Column): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val replaced = schema.exists(f => resolver(f.name, colName)) + if (replaced) { + val colNames = schema.map { field => + val name = field.name + if (resolver(name, colName)) col.as(colName) else Column(name) + } + select(colNames :_*) + } else { + select(Column("*"), col.as(colName)) + } + } /** * Returns a new [[DataFrame]] with a column renamed. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b26e22f6229fe..e22034b499703 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -457,6 +457,10 @@ class DataFrameSuite extends QueryTest { Row(key, value, key + 1) }.toSeq) assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) + + val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df3 = df2.withColumn("x", df2("x") + 1) + assert(df3.select("x").collect().toSeq === Seq(Row(2), Row(3), Row(4))) } test("withColumnRenamed") { From b539c7b7aa55c095163d06bac525d1bb90c0b734 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 17 Apr 2015 16:06:00 +0800 Subject: [PATCH 2/2] For comment. --- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e22034b499703..862f662728240 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -457,10 +457,14 @@ class DataFrameSuite extends QueryTest { Row(key, value, key + 1) }.toSeq) assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) + } + test("replace column using withColumn") { val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) - assert(df3.select("x").collect().toSeq === Seq(Row(2), Row(3), Row(4))) + checkAnswer( + df3.select("x"), + Row(2) :: Row(3) :: Row(4) :: Nil) } test("withColumnRenamed") {