-
Notifications
You must be signed in to change notification settings - Fork 175
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cdcd749
commit a5c1596
Showing
10 changed files
with
323 additions
and
3 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
use std::{collections::HashMap, future::ready}; | ||
|
||
use common_daft_config::DaftExecutionConfig; | ||
use common_file_formats::FileFormat; | ||
use eyre::{bail, WrapErr}; | ||
use futures::stream; | ||
use spark_connect::{ | ||
write_operation::{SaveMode, SaveType}, | ||
ExecutePlanResponse, Relation, WriteOperation, | ||
}; | ||
use tokio_util::sync::CancellationToken; | ||
use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status}; | ||
use tracing::warn; | ||
|
||
use crate::{ | ||
invalid_argument_err, | ||
op::execute::{ExecuteStream, PlanIds}, | ||
session::Session, | ||
translation, | ||
}; | ||
|
||
impl Session { | ||
pub async fn handle_write_command( | ||
&self, | ||
operation: WriteOperation, | ||
operation_id: String, | ||
) -> Result<ExecuteStream, Status> { | ||
use futures::{StreamExt, TryStreamExt}; | ||
|
||
let context = PlanIds { | ||
session: self.client_side_session_id().to_string(), | ||
server_side_session: self.server_side_session_id().to_string(), | ||
operation: operation_id, | ||
}; | ||
|
||
let finished = context.finished(); | ||
|
||
// operation: WriteOperation { | ||
// input: Some( | ||
// Relation { | ||
// common: Some( | ||
// RelationCommon { | ||
// source_info: "", | ||
// plan_id: Some( | ||
// 0, | ||
// ), | ||
// origin: None, | ||
// }, | ||
// ), | ||
// rel_type: Some( | ||
// Range( | ||
// Range { | ||
// start: Some( | ||
// 0, | ||
// ), | ||
// end: 10, | ||
// step: 1, | ||
// num_partitions: None, | ||
// }, | ||
// ), | ||
// ), | ||
// }, | ||
// ), | ||
// source: Some( | ||
// "parquet", | ||
// ), | ||
// mode: Unspecified, | ||
// sort_column_names: [], | ||
// partitioning_columns: [], | ||
// bucket_by: None, | ||
// options: {}, | ||
// clustering_columns: [], | ||
// save_type: Some( | ||
// Path( | ||
// "/var/folders/zy/g1zccty96bg_frmz9x0198zh0000gn/T/tmpxki7yyr0/test.parquet", | ||
// ), | ||
// ), | ||
// } | ||
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(16); | ||
std::thread::spawn(move || { | ||
let result = (|| -> eyre::Result<()> { | ||
let WriteOperation { | ||
input, | ||
source, | ||
mode, | ||
sort_column_names, | ||
partitioning_columns, | ||
bucket_by, | ||
options, | ||
clustering_columns, | ||
save_type, | ||
} = operation; | ||
|
||
let Some(input) = input else { | ||
bail!("Input is required"); | ||
}; | ||
|
||
let Some(source) = source else { | ||
bail!("Source is required"); | ||
}; | ||
|
||
if source != "parquet" { | ||
bail!("Unsupported source: {source}; only parquet is supported"); | ||
} | ||
|
||
let Ok(mode) = SaveMode::try_from(mode) else { | ||
bail!("Invalid save mode: {mode}"); | ||
}; | ||
|
||
if !sort_column_names.is_empty() { | ||
// todo(completeness): implement sort | ||
warn!( | ||
"Ignoring sort_column_names: {sort_column_names:?} (not yet implemented)" | ||
); | ||
} | ||
|
||
if !partitioning_columns.is_empty() { | ||
// todo(completeness): implement partitioning | ||
warn!("Ignoring partitioning_columns: {partitioning_columns:?} (not yet implemented)"); | ||
} | ||
|
||
if let Some(bucket_by) = bucket_by { | ||
// todo(completeness): implement bucketing | ||
warn!("Ignoring bucket_by: {bucket_by:?} (not yet implemented)"); | ||
} | ||
|
||
if !options.is_empty() { | ||
// todo(completeness): implement options | ||
warn!("Ignoring options: {options:?} (not yet implemented)"); | ||
} | ||
|
||
if !clustering_columns.is_empty() { | ||
// todo(completeness): implement clustering | ||
warn!( | ||
"Ignoring clustering_columns: {clustering_columns:?} (not yet implemented)" | ||
); | ||
} | ||
|
||
match mode { | ||
SaveMode::Unspecified => {} | ||
SaveMode::Append => {} | ||
SaveMode::Overwrite => {} | ||
SaveMode::ErrorIfExists => {} | ||
SaveMode::Ignore => {} | ||
} | ||
|
||
let Some(save_type) = save_type else { | ||
return bail!("Save type is required"); | ||
}; | ||
|
||
let path = match save_type { | ||
SaveType::Path(path) => path, | ||
SaveType::Table(table) => { | ||
let name = table.table_name; | ||
bail!("Tried to write to table {name} but it is not yet implemented. Try to write to a path instead."); | ||
} | ||
}; | ||
|
||
let plan = translation::to_logical_plan(input)?; | ||
|
||
let plan = plan | ||
.table_write(&path, FileFormat::Parquet, None, None, None) | ||
.wrap_err("Failed to create table write plan")?; | ||
|
||
let logical_plan = plan.build(); | ||
let physical_plan = daft_local_plan::translate(&logical_plan)?; | ||
|
||
let cfg = DaftExecutionConfig::default(); | ||
|
||
// "hot" flow not a "cold" flow | ||
let iterator = daft_local_execution::run_local( | ||
&physical_plan, | ||
HashMap::new(), | ||
cfg.into(), | ||
None, | ||
CancellationToken::new(), // todo: maybe implement cancelling | ||
)?; | ||
|
||
for _ignored in iterator { | ||
|
||
} | ||
|
||
// this is so we make sure the operation is actually done | ||
// before we return | ||
// | ||
// an example where this is important is if we write to a parquet file | ||
// and then read immediately after, we need to wait for the write to finish | ||
|
||
Ok(()) | ||
})(); | ||
|
||
if let Err(e) = result { | ||
tx.blocking_send(Err(e)).unwrap(); | ||
} | ||
}); | ||
|
||
let stream = ReceiverStream::new(rx); | ||
|
||
let stream = stream | ||
.map_err(|e| Status::internal(format!("Error in Daft server: {e:?}"))) | ||
.chain(stream::once(ready(Ok(finished)))); | ||
|
||
Ok(Box::pin(stream)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
use daft_logical_plan::LogicalPlanBuilder; | ||
use eyre::{bail, WrapErr}; | ||
use spark_connect::read::ReadType; | ||
use tracing::warn; | ||
|
||
mod data_source; | ||
|
||
pub fn read(read: spark_connect::Read) -> eyre::Result<LogicalPlanBuilder> { | ||
let spark_connect::Read { | ||
is_streaming, | ||
read_type, | ||
} = read; | ||
|
||
warn!("Ignoring is_streaming: {is_streaming}"); | ||
|
||
let Some(read_type) = read_type else { | ||
bail!("Read type is required"); | ||
}; | ||
|
||
match read_type { | ||
ReadType::NamedTable(table) => { | ||
let name = table.unparsed_identifier; | ||
bail!("Tried to read from table {name} but it is not yet implemented. Try to read from a path instead."); | ||
} | ||
ReadType::DataSource(source) => { | ||
data_source::data_source(source).wrap_err("Failed to create data source") | ||
} | ||
} | ||
} |
42 changes: 42 additions & 0 deletions
42
src/daft-connect/src/translation/logical_plan/read/data_source.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
use daft_logical_plan::LogicalPlanBuilder; | ||
use daft_scan::builder::ParquetScanBuilder; | ||
use eyre::{bail, ensure, WrapErr}; | ||
use tracing::warn; | ||
|
||
pub fn data_source(data_source: spark_connect::read::DataSource) -> eyre::Result<LogicalPlanBuilder> { | ||
let spark_connect::read::DataSource { | ||
format, | ||
schema, | ||
options, | ||
paths, | ||
predicates, | ||
} = data_source; | ||
|
||
let Some(format) = format else { | ||
bail!("Format is required"); | ||
}; | ||
|
||
if format != "parquet" { | ||
bail!("Unsupported format: {format}; only parquet is supported"); | ||
} | ||
|
||
ensure!(!paths.is_empty(), "Paths are required"); | ||
|
||
if let Some(schema) = schema { | ||
warn!("Ignoring schema: {schema:?}; not yet implemented"); | ||
} | ||
|
||
if !options.is_empty() { | ||
warn!("Ignoring options: {options:?}; not yet implemented"); | ||
} | ||
|
||
if !predicates.is_empty() { | ||
warn!("Ignoring predicates: {predicates:?}; not yet implemented"); | ||
} | ||
|
||
let builder = ParquetScanBuilder::new(paths) | ||
.finish() | ||
.wrap_err("Failed to create parquet scan builder")?; | ||
|
||
Ok(builder) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from __future__ import annotations | ||
|
||
import tempfile | ||
import shutil | ||
import os | ||
|
||
|
||
def test_write_parquet(spark_session): | ||
# Create a temporary directory | ||
temp_dir = tempfile.mkdtemp() | ||
try: | ||
# Create DataFrame from range(10) | ||
df = spark_session.range(10) | ||
|
||
# Write DataFrame to parquet directory | ||
parquet_dir = os.path.join(temp_dir, "test.parquet") | ||
df.write.parquet(parquet_dir) | ||
|
||
# List all files in the parquet directory | ||
parquet_files = [f for f in os.listdir(parquet_dir) if f.endswith('.parquet')] | ||
print(f"Parquet files in directory: {parquet_files}") | ||
|
||
# Assert there is at least one parquet file | ||
assert len(parquet_files) > 0, "Expected at least one parquet file to be written" | ||
|
||
# Read back from the parquet directory (not specific file) | ||
df_read = spark_session.read.parquet(parquet_dir) | ||
|
||
# Verify the data is unchanged | ||
df_pandas = df.toPandas() | ||
df_read_pandas = df_read.toPandas() | ||
assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read" | ||
|
||
finally: | ||
# Clean up temp directory | ||
shutil.rmtree(temp_dir) |