Skip to content

Commit

Permalink
[FEAT] connect: add drop support
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 11, 2024
1 parent 5238279 commit b85510e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ use spark_connect::{relation::RelType, Limit, Relation};
use tracing::warn;

use crate::translation::logical_plan::{
aggregate::aggregate, filter::filter, local_relation::local_relation, project::project,
range::range, read::read, to_df::to_df, with_columns::with_columns,
aggregate::aggregate, drop::drop, filter::filter, local_relation::local_relation,
project::project, range::range, read::read, to_df::to_df, with_columns::with_columns,
};

mod aggregate;
mod drop;
mod filter;
mod local_relation;
mod project;
Expand Down Expand Up @@ -78,6 +79,9 @@ pub async fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
RelType::Read(r) => read(r)
.await
.wrap_err("Failed to apply read to logical plan"),
RelType::Drop(d) => drop(*d)
.await
.wrap_err("Failed to apply drop to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
Expand Down
39 changes: 39 additions & 0 deletions src/daft-connect/src/translation/logical_plan/drop.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use eyre::bail;

use crate::translation::{to_logical_plan, Plan};

pub async fn drop(drop: spark_connect::Drop) -> eyre::Result<Plan> {
let spark_connect::Drop {
input,
columns,
column_names,
} = drop;

let Some(input) = input else {
bail!("input is required");
};

if !columns.is_empty() {
bail!("columns is not supported; use column_names instead");
}

let mut plan = Box::pin(to_logical_plan(*input)).await?;

// Get all column names from the schema
let all_columns = plan.builder.schema().names();

// Create a set of columns to drop for efficient lookup
let columns_to_drop: std::collections::HashSet<_> = column_names.iter().collect();

// Create expressions for all columns except the ones being dropped
let to_select = all_columns
.iter()
.filter(|col_name| !columns_to_drop.contains(*col_name))
.map(|col_name| daft_dsl::col(col_name.clone()))
.collect();

// Use select to keep only the columns we want
plan.builder = plan.builder.select(to_select)?;

Ok(plan)
}
17 changes: 17 additions & 0 deletions tests/connect/test_drop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations


def test_drop(spark_session):
# Create DataFrame from range(10)
df = spark_session.range(10)

# Drop the 'id' column
df_dropped = df.drop("id")

# Verify the drop was successful
assert "id" not in df_dropped.columns, "Column 'id' should be dropped"
assert len(df_dropped.columns) == len(df.columns) - 1, "Should have one less column after drop"

# Verify the DataFrame has no columns after dropping all columns"
assert len(df_dropped.toPandas().columns) == 0, "DataFrame should have no columns after dropping 'id'"
assert df_dropped.count() == df.count(), "Row count should be unchanged after drop"

0 comments on commit b85510e

Please sign in to comment.