feat(connect): read/write → csv, write → json #3361

Merged
merged 1 commit into from
Dec 19, 2024
6 changes: 2 additions & 4 deletions src/daft-connect/src/op/execute/write.rs
@@ -55,9 +55,7 @@ impl Session
bail!("Source is required");
};

if source != "parquet" {
bail!("Unsupported source: {source}; only parquet is supported");
}
let file_format: FileFormat = source.parse()?;

let Ok(mode) = SaveMode::try_from(mode) else {
bail!("Invalid save mode: {mode}");
@@ -115,7 +113,7 @@ impl Session
let plan = translator.to_logical_plan(input).await?;

let plan = plan
.table_write(&path, FileFormat::Parquet, None, None, None)
.table_write(&path, file_format, None, None, None)
.wrap_err("Failed to create table write plan")?;

let optimized_plan = plan.optimize()?;
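With the hardcoded parquet check removed, the write path derives the file format from the client-supplied source string via source.parse()?, so the same table_write call can serve parquet, csv, and json writes. A minimal client-side sketch of what this enables, assuming a Spark Connect session backed by Daft (spark_session) and placeholder output paths that are not part of this diff:

# Sketch only: spark_session and the /tmp paths are assumptions for illustration.
df = spark_session.range(10)

# CSV writes now resolve through FileFormat parsing instead of bailing on non-parquet sources.
df.write.csv("/tmp/daft_connect_out/csv")

# JSON writes should take the same dispatch path; they are write-only here, since json reads still bail (see below).
df.write.json("/tmp/daft_connect_out/json")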
29 changes: 19 additions & 10 deletions src/daft-connect/src/translation/logical_plan/read/data_source.rs
@@ -1,5 +1,5 @@
use daft_logical_plan::LogicalPlanBuilder;
use daft_scan::builder::ParquetScanBuilder;
use daft_scan::builder::{CsvScanBuilder, ParquetScanBuilder};
use eyre::{bail, ensure, WrapErr};
use tracing::warn;

@@ -18,10 +18,6 @@
bail!("Format is required");
};

if format != "parquet" {
bail!("Unsupported format: {format}; only parquet is supported");
}

ensure!(!paths.is_empty(), "Paths are required");

if let Some(schema) = schema {
@@ -36,10 +32,23 @@
warn!("Ignoring predicates: {predicates:?}; not yet implemented");
}

let builder = ParquetScanBuilder::new(paths)
.finish()
.await
.wrap_err("Failed to create parquet scan builder")?;
let plan = match &*format {
"parquet" => ParquetScanBuilder::new(paths)
.finish()
.await
.wrap_err("Failed to create parquet scan builder")?,
"csv" => CsvScanBuilder::new(paths)
.finish()
.await
.wrap_err("Failed to create csv scan builder")?,
"json" => {

// todo(completeness): implement json reading
bail!("json reading is not yet implemented");

}
other => {
bail!("Unsupported format: {other}; only parquet and csv are supported");

}
};

Ok(builder)
Ok(plan)
}
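On the read side, the format string now selects the scan builder: parquet maps to ParquetScanBuilder, csv to CsvScanBuilder, and json deliberately bails until json reading is implemented. A hedged sketch of the expected client-facing behavior, assuming the same Daft-backed Spark Connect session and placeholder paths that already contain data:

# Sketch only: the /tmp paths are assumptions for illustration.
parquet_df = spark_session.read.parquet("/tmp/data/parquet_dir")  # ParquetScanBuilder
csv_df = spark_session.read.csv("/tmp/data/csv_dir")              # CsvScanBuilder

# A json read is expected to surface the "json reading is not yet implemented" error.
try:
    spark_session.read.json("/tmp/data/json_dir")
except Exception as exc:
    print(exc)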
88 changes: 88 additions & 0 deletions tests/connect/test_csv.py
@@ -0,0 +1,88 @@
from __future__ import annotations

import os

import pytest


def test_write_csv_basic(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
df.write.csv(csv_dir)

csv_files = [f for f in os.listdir(csv_dir) if f.endswith(".csv")]
assert len(csv_files) > 0, "Expected at least one CSV file to be written"

df_read = spark_session.read.csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read"


def test_write_csv_with_header(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("header", True).csv(csv_dir)

df_read = spark_session.read.option("header", True).csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"])


def test_write_csv_with_delimiter(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("sep", "|").csv(csv_dir)

df_read = spark_session.read.option("sep", "|").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"])


def test_write_csv_with_quote(spark_session, tmp_path):
df = spark_session.createDataFrame([("a,b",), ("c'd",)], ["text"])
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("quote", "'").csv(csv_dir)

df_read = spark_session.read.option("quote", "'").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["text"].equals(df_read_pandas["text"])


def test_write_csv_with_escape(spark_session, tmp_path):
df = spark_session.createDataFrame([("a'b",), ("c'd",)], ["text"])
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("escape", "\\").csv(csv_dir)

df_read = spark_session.read.option("escape", "\\").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["text"].equals(df_read_pandas["text"])


@pytest.mark.skip(
reason="https://github.com/Eventual-Inc/Daft/issues/3609: CSV null value handling not yet implemented"
)
def test_write_csv_with_null_value(spark_session, tmp_path):
df = spark_session.createDataFrame([(1, None), (2, "test")], ["id", "value"])
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("nullValue", "NULL").csv(csv_dir)

df_read = spark_session.read.option("nullValue", "NULL").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["value"].isna().equals(df_read_pandas["value"].isna())


def test_write_csv_with_compression(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("compression", "gzip").csv(csv_dir)

df_read = spark_session.read.csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"])