[SPARK-43046] [SS] [Connect] Implemented Python API dropDuplicatesWithinWatermark for Spark Connect #40834

Closed
wants to merge 17 commits
Changes from 9 commits
@@ -68,6 +68,7 @@ message Relation {
WithWatermark with_watermark = 33;
ApplyInPandasWithState apply_in_pandas_with_state = 34;
HtmlString html_string = 35;
Deduplicate deduplicate_within_watermark = 36;
Review comment: This is not required, right? The flag in Deduplicate will indicate this.


// NA functions
NAFill fill_na = 90;
@@ -362,6 +363,9 @@ message Deduplicate {
//
// This field does not co-use with `column_names`.
optional bool all_columns_as_keys = 3;

// (Optional) Deduplicate within the time range of watermark.
optional bool within_watermark = 4;
}

// A relation that does not need to be qualified by name.
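Picking up the review comment above: since `Deduplicate` now carries a `within_watermark` flag, a client does not strictly need the separate `deduplicate_within_watermark` relation field. A minimal sketch of the flag-based approach, assuming the generated Python bindings under `pyspark.sql.connect.proto` include the new field (the child relation is only a placeholder):

```python
from pyspark.sql.connect import proto

# Placeholder child; a real client would pass the input plan's relation here.
child = proto.Relation()

# Reuse the existing `deduplicate` relation and mark it as within-watermark
# via the new flag, rather than populating a separate
# deduplicate_within_watermark relation field.
rel = proto.Relation()
rel.deduplicate.input.CopyFrom(child)
rel.deduplicate.column_names.extend(["id", "name"])
rel.deduplicate.within_watermark = True
```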
@@ -596,6 +596,17 @@ package object dsl {
.addAllColumnNames(colNames.asJava))
.build()

def deduplicateWithinWatermark(colNames: Seq[String]): Relation =
Relation
.newBuilder()
.setDeduplicateWithinWatermark(
Deduplicate
.newBuilder()
.setInput(logicalPlan)
.addAllColumnNames(colNames.asJava)
.setWithinWatermark(true))
.build()

def distinct(): Relation =
Relation
.newBuilder()
@@ -46,7 +46,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, DeserializeToObject, Except, Intersect, LocalRelation, LogicalPlan, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, Intersect, LocalRelation, LogicalPlan, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket}
@@ -90,6 +90,8 @@ class SparkConnectPlanner(val session: SparkSession) {
case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail)
case proto.Relation.RelTypeCase.JOIN => transformJoin(rel.getJoin)
case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate)
case proto.Relation.RelTypeCase.DEDUPLICATE_WITHIN_WATERMARK =>
Review comment: Connected to the proto comment. We don't need this field.

transformDeduplicate(rel.getDeduplicateWithinWatermark, isWithinWatermark = true)
case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp)
case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop)
@@ -723,7 +725,8 @@ class SparkConnectPlanner(val session: SparkSession) {
CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput))
}

private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
private def transformDeduplicate(rel: proto.Deduplicate,
isWithinWatermark: Boolean = false): LogicalPlan = {
if (!rel.hasInput) {
throw InvalidPlanInput("Deduplicate needs a plan input")
}
@@ -738,7 +741,8 @@ class SparkConnectPlanner(val session: SparkSession) {
val resolver = session.sessionState.analyzer.resolver
val allColumns = queryExecution.analyzed.output
if (rel.getAllColumnsAsKeys) {
Deduplicate(allColumns, queryExecution.analyzed)
if (isWithinWatermark) DeduplicateWithinWatermark(allColumns, queryExecution.analyzed)
else Deduplicate(allColumns, queryExecution.analyzed)
} else {
val toGroupColumnNames = rel.getColumnNamesList.asScala.toSeq
val groupCols = toGroupColumnNames.flatMap { (colName: String) =>
@@ -750,7 +754,8 @@ class SparkConnectPlanner(val session: SparkSession) {
}
cols
}
Deduplicate(groupCols, queryExecution.analyzed)
if (isWithinWatermark) DeduplicateWithinWatermark(groupCols, queryExecution.analyzed)
else Deduplicate(groupCols, queryExecution.analyzed)
}
}

@@ -381,6 +381,36 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
assert(e2.getMessage.contains("either deduplicate on all columns or a subset of columns"))
}

test("Test invalid deduplicateWithinWatermark") {
val deduplicateWithinWatermark = proto.Deduplicate
.newBuilder()
.setInput(readRel)
.setAllColumnsAsKeys(true)
.addColumnNames("test")
.setWithinWatermark(true)

val e = intercept[InvalidPlanInput] {
transform(
proto.Relation.newBuilder
.setDeduplicateWithinWatermark(deduplicateWithinWatermark)
.build())
}
assert(
e.getMessage.contains("Cannot deduplicate on both all columns and a subset of columns"))

val deduplicateWithinWatermark2 = proto.Deduplicate
.newBuilder()
.setInput(readRel)
.setWithinWatermark(true)
val e2 = intercept[InvalidPlanInput] {
transform(
proto.Relation.newBuilder
.setDeduplicateWithinWatermark(deduplicateWithinWatermark2)
.build())
}
assert(e2.getMessage.contains("either deduplicate on all columns or a subset of columns"))
}

test("Test invalid intersect, except") {
// Except with union_by_name=true
val except = proto.SetOperation
@@ -348,6 +348,16 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan2, sparkPlan2)
}

test("Test basic deduplicateWithinWatermark") {
val connectPlan = connectTestRelation.distinct()
val sparkPlan = sparkTestRelation.distinct()
comparePlans(connectPlan, sparkPlan)

val connectPlan2 = connectTestRelation.deduplicateWithinWatermark(Seq("id", "name"))
val sparkPlan2 = sparkTestRelation.dropDuplicatesWithinWatermark(Seq("id", "name"))
comparePlans(connectPlan2, sparkPlan2)
}

test("Test union, except, intersect") {
val connectPlan1 = connectTestRelation.except(connectTestRelation, isAll = false)
val sparkPlan1 = sparkTestRelation.except(sparkTestRelation)
19 changes: 18 additions & 1 deletion python/pyspark/sql/connect/dataframe.py
@@ -379,7 +379,24 @@ def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame":
    drop_duplicates = dropDuplicates

    def dropDuplicatesWithinWatermark(self, subset: Optional[List[str]] = None) -> "DataFrame":
        raise NotImplementedError("dropDuplicatesWithinWatermark() is not implemented.")
        if subset is not None and not isinstance(subset, (list, tuple)):
            raise PySparkTypeError(
                error_class="NOT_LIST_OR_TUPLE",
                message_parameters={"arg_name": "subset", "arg_type": type(subset).__name__},
            )

        if subset is None:
            return DataFrame.withPlan(
                plan.DeduplicateWithinWatermark(child=self._plan, all_columns_as_keys=True),
                session=self._session,
            )
        else:
            return DataFrame.withPlan(
                plan.DeduplicateWithinWatermark(child=self._plan, column_names=subset),
                session=self._session,
            )

    dropDuplicatesWithinWatermark.__doc__ = PySparkDataFrame.dropDuplicatesWithinWatermark.__doc__

    drop_duplicates_within_watermark = dropDuplicatesWithinWatermark

    def distinct(self) -> "DataFrame":
        return DataFrame.withPlan(
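For context, a usage sketch of the API this change adds. This assumes a Spark Connect session (the `sc://localhost` URL and the rate source are illustrative) and that streaming execution of this operator is available over Connect; `dropDuplicatesWithinWatermark` requires a watermark on a streaming DataFrame:

```python
from pyspark.sql import SparkSession

# Illustrative Connect endpoint.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Illustrative streaming source: the rate source emits (timestamp, value) rows.
events = spark.readStream.format("rate").option("rowsPerSecond", "10").load()

# Drop duplicate `value`s, keeping deduplication state only as long as the
# 10-minute watermark on the event-time column allows.
deduped = (
    events.withWatermark("timestamp", "10 minutes")
    .dropDuplicatesWithinWatermark(["value"])
)

query = deduped.writeStream.format("console").start()
query.awaitTermination(30)
```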
21 changes: 21 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -623,6 +623,27 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
        return plan


class DeduplicateWithinWatermark(LogicalPlan):
Review comment: We don't need this anymore, right? We can add the flag to Deduplicate above to match the rest of the PR.

Author reply: Yeah, I noticed that as well; doing the change now.

Author reply: Updated.

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        all_columns_as_keys: bool = False,
        column_names: Optional[List[str]] = None,
    ) -> None:
        super().__init__(child)
        self.all_columns_as_keys = all_columns_as_keys
        self.column_names = column_names

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        plan = self._create_proto_relation()
        plan.deduplicate_within_watermark.input.CopyFrom(self._child.plan(session))
        plan.deduplicate_within_watermark.all_columns_as_keys = self.all_columns_as_keys
        if self.column_names is not None:
            plan.deduplicate_within_watermark.column_names.extend(self.column_names)
        return plan


class Sort(LogicalPlan):
    def __init__(
        self,
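As the review thread above suggests (and as the author says they applied in later commits), the separate plan class can be folded into the existing `Deduplicate` plan by adding a flag. A minimal sketch of that direction, written as it would sit in python/pyspark/sql/connect/plan.py next to the classes above (names and defaults are illustrative, not the exact final code):

```python
class Deduplicate(LogicalPlan):
    """Deduplicate plan that can optionally deduplicate within the watermark."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        all_columns_as_keys: bool = False,
        column_names: Optional[List[str]] = None,
        within_watermark: bool = False,
    ) -> None:
        super().__init__(child)
        self.all_columns_as_keys = all_columns_as_keys
        self.column_names = column_names
        self.within_watermark = within_watermark

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        plan = self._create_proto_relation()
        plan.deduplicate.input.CopyFrom(self._child.plan(session))
        plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys
        plan.deduplicate.within_watermark = self.within_watermark
        if self.column_names is not None:
            plan.deduplicate.column_names.extend(self.column_names)
        return plan
```

With that consolidation, both code paths in dataframe.py would construct `plan.Deduplicate(..., within_watermark=True)` instead of a separate plan class.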
1 change: 0 additions & 1 deletion python/pyspark/sql/connect/proto/base_pb2.py
@@ -699,7 +699,6 @@

_SPARKCONNECTSERVICE = DESCRIPTOR.services_by_name["SparkConnectService"]
if _descriptor._USE_C_DESCRIPTORS == False:

    DESCRIPTOR._options = None
    DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None