From e8054870fb44d4ac3ab0794a7b3f411367d0988f Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Tue, 3 Dec 2024 01:01:25 +1100 Subject: [PATCH 01/20] refactor(rust): Replace custom `PushNode` trait with `Extend` (#20107) --- .../polars-plan/src/plans/aexpr/traverse.rs | 59 ++++++++----------- crates/polars-plan/src/plans/aexpr/utils.rs | 2 +- crates/polars-plan/src/plans/ir/inputs.rs | 19 ++---- crates/polars-plan/src/utils.rs | 28 --------- 4 files changed, 30 insertions(+), 78 deletions(-) diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs index 20e7a454169c..3850cbafca6e 100644 --- a/crates/polars-plan/src/plans/aexpr/traverse.rs +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -2,39 +2,35 @@ use super::*; impl AExpr { /// Push nodes at this level to a pre-allocated stack. - pub(crate) fn nodes(&self, container: &mut impl PushNode) { + pub(crate) fn nodes(&self, container: &mut E) + where + E: Extend, + { use AExpr::*; match self { Column(_) | Literal(_) | Len => {}, - Alias(e, _) => container.push_node(*e), + Alias(e, _) => container.extend([*e]), BinaryExpr { left, op: _, right } => { // reverse order so that left is popped first - container.push_node(*right); - container.push_node(*left); + container.extend([*right, *left]); }, - Cast { expr, .. } => container.push_node(*expr), - Sort { expr, .. } => container.push_node(*expr), + Cast { expr, .. } => container.extend([*expr]), + Sort { expr, .. } => container.extend([*expr]), Gather { expr, idx, .. } => { - container.push_node(*idx); - // latest, so that it is popped first - container.push_node(*expr); + container.extend([*idx, *expr]); }, SortBy { expr, by, .. 
} => { - for node in by { - container.push_node(*node) - } + container.extend(by.iter().cloned()); // latest, so that it is popped first - container.push_node(*expr); + container.extend([*expr]); }, Filter { input, by } => { - container.push_node(*by); - // latest, so that it is popped first - container.push_node(*input); + container.extend([*by, *input]); }, Agg(agg_e) => match agg_e.get_input() { - NodeInputs::Single(node) => container.push_node(node), - NodeInputs::Many(nodes) => container.extend_from_slice(&nodes), + NodeInputs::Single(node) => container.extend([node]), + NodeInputs::Many(nodes) => container.extend(nodes), NodeInputs::Leaf => {}, }, Ternary { @@ -42,21 +38,15 @@ impl AExpr { falsy, predicate, } => { - container.push_node(*predicate); - container.push_node(*falsy); - // latest, so that it is popped first - container.push_node(*truthy); + container.extend([*predicate, *falsy, *truthy]); }, AnonymousFunction { input, .. } | Function { input, .. } => // we iterate in reverse order, so that the lhs is popped first and will be found // as the root columns/ input columns by `_suffix` and `_keep_name` etc. 
{ - input - .iter() - .rev() - .for_each(|e| container.push_node(e.node())) + container.extend(input.iter().rev().map(|e| e.node())) }, - Explode(e) => container.push_node(*e), + Explode(e) => container.extend([*e]), Window { function, partition_by, @@ -64,23 +54,20 @@ impl AExpr { options: _, } => { if let Some((n, _)) = order_by { - container.push_node(*n); - } - for e in partition_by.iter().rev() { - container.push_node(*e); + container.extend([*n]); } + + container.extend(partition_by.iter().rev().cloned()); + // latest so that it is popped first - container.push_node(*function); + container.extend([*function]); }, Slice { input, offset, length, } => { - container.push_node(*length); - container.push_node(*offset); - // latest so that it is popped first - container.push_node(*input); + container.extend([*length, *offset, *input]); }, } } diff --git a/crates/polars-plan/src/plans/aexpr/utils.rs b/crates/polars-plan/src/plans/aexpr/utils.rs index 92657acc5340..834aa4ff6a75 100644 --- a/crates/polars-plan/src/plans/aexpr/utils.rs +++ b/crates/polars-plan/src/plans/aexpr/utils.rs @@ -26,7 +26,7 @@ pub fn is_elementwise(stack: &mut UnitVec, ae: &AExpr, expr_arena: &Arena< let rhs = rhs.node(); if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) { - stack.push_node(input[0].node()); + stack.extend([input[0].node()]); return; } }; diff --git a/crates/polars-plan/src/plans/ir/inputs.rs b/crates/polars-plan/src/plans/ir/inputs.rs index 2a7c14e300de..865d398959ff 100644 --- a/crates/polars-plan/src/plans/ir/inputs.rs +++ b/crates/polars-plan/src/plans/ir/inputs.rs @@ -214,20 +214,16 @@ impl IR { /// or an in-memory DataFrame has none. A Union has multiple. pub fn copy_inputs(&self, container: &mut T) where - T: PushNode, + T: Extend, { use IR::*; let input = match self { Union { inputs, .. } => { - for node in inputs { - container.push_node(*node); - } + container.extend(inputs.iter().cloned()); return; }, HConcat { inputs, .. 
} => { - for node in inputs { - container.push_node(*node); - } + container.extend(inputs.iter().cloned()); return; }, Slice { input, .. } => *input, @@ -243,8 +239,7 @@ impl IR { input_right, .. } => { - container.push_node(*input_left); - container.push_node(*input_right); + container.extend([*input_left, *input_right]); return; }, HStack { input, .. } => *input, @@ -254,9 +249,7 @@ impl IR { ExtContext { input, contexts, .. } => { - for n in contexts { - container.push_node(*n) - } + container.extend(contexts.iter().cloned()); *input }, Scan { .. } => return, @@ -265,7 +258,7 @@ impl IR { PythonScan { .. } => return, Invalid => unreachable!(), }; - container.push_node(input) + container.extend([input]) } pub fn get_inputs(&self) -> UnitVec { diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index fdf9c979738a..0d433e9e74d7 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -2,7 +2,6 @@ use std::fmt::Formatter; use std::iter::FlatMap; use polars_core::prelude::*; -use polars_utils::idx_vec::UnitVec; use crate::constants::get_len_name; use crate::prelude::*; @@ -40,33 +39,6 @@ pub(crate) fn fmt_column_delimited>( write!(f, "{container_end}") } -// TODO: Remove this and use `Extend` instead. -pub trait PushNode { - fn push_node(&mut self, value: Node); - - fn extend_from_slice(&mut self, values: &[Node]); -} - -impl PushNode for Vec { - fn push_node(&mut self, value: Node) { - self.push(value) - } - - fn extend_from_slice(&mut self, values: &[Node]) { - Vec::extend_from_slice(self, values) - } -} - -impl PushNode for UnitVec { - fn push_node(&mut self, value: Node) { - self.push(value) - } - - fn extend_from_slice(&mut self, values: &[Node]) { - UnitVec::extend(self, values.iter().copied()) - } -} - pub(crate) fn is_scan(plan: &IR) -> bool { matches!(plan, IR::Scan { .. } | IR::DataFrameScan { .. 
}) } From 4c1c51c9d7131e0c8ac290390501dd8af8df165a Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Mon, 2 Dec 2024 20:12:03 +0400 Subject: [PATCH 02/20] build: Upgrade `sqlparser-rs` from version 0.49 to 0.52 (#20110) --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- crates/polars-sql/src/context.rs | 9 ++++++--- crates/polars-sql/src/sql_expr.rs | 27 ++++++++++++++++++++++----- 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3c24e9539f40..c2061e769593 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4495,9 +4495,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.49.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e" +checksum = "9a875d8cd437cc8a97e9aeaeea352ec9a19aea99c23e9effb17757291de80b08" dependencies = [ "log", ] diff --git a/Cargo.toml b/Cargo.toml index 6a0366668b5f..4249c5d4b494 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,7 +77,7 @@ serde_json = "1" simd-json = { version = "0.14", features = ["known-key"] } simdutf8 = "0.1.4" slotmap = "1" -sqlparser = "0.49" +sqlparser = "0.52" stacker = "0.1" streaming-iterator = "0.1.9" strength_reduce = "0.2" diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 8b17645b3fb7..3ec7eb3e243a 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -501,14 +501,17 @@ impl SQLContext { fn execute_truncate_table(&mut self, stmt: &Statement) -> PolarsResult { if let Statement::Truncate { - table_name, + table_names, partitions, .. 
} = stmt { match partitions { None => { - let tbl = table_name.to_string(); + if table_names.len() != 1 { + polars_bail!(SQLInterface: "TRUNCATE expects exactly one table name; found {}", table_names.len()) + } + let tbl = table_names[0].to_string(); if let Some(lf) = self.table_map.get_mut(&tbl) { *lf = DataFrame::empty_with_schema( lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena) @@ -971,7 +974,7 @@ impl SQLContext { name, alias, args, .. } => { if let Some(args) = args { - return self.execute_table_function(name, alias, args); + return self.execute_table_function(name, alias, &args.args); } let tbl_name = name.0.first().unwrap().value.as_str(); if let Some(lf) = self.get_table_from_current_scope(tbl_name) { diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index a2ada46e1c68..4f0f44b826fb 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -93,6 +93,7 @@ impl SQLExprVisitor<'_> { left, compare_op, right, + is_some: _, } => self.visit_any(left, compare_op, right), SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None), SQLExpr::Between { @@ -110,9 +111,11 @@ impl SQLExprVisitor<'_> { } => self.visit_cast(expr, data_type, format, kind), SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()), SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents), - SQLExpr::Extract { field, expr } => { - parse_extract_date_part(self.visit_expr(expr)?, field) - }, + SQLExpr::Extract { + field, + syntax: _, + expr, + } => parse_extract_date_part(self.visit_expr(expr)?, field), SQLExpr::Floor { expr, .. 
} => Ok(self.visit_expr(expr)?.floor()), SQLExpr::Function(function) => self.visit_function(function), SQLExpr::Identifier(ident) => self.visit_identifier(ident), @@ -146,16 +149,28 @@ impl SQLExprVisitor<'_> { SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))), SQLExpr::Like { negated, + any, expr, pattern, escape_char, - } => self.visit_like(*negated, expr, pattern, escape_char, false), + } => { + if *any { + polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax") + } + self.visit_like(*negated, expr, pattern, escape_char, false) + }, SQLExpr::ILike { negated, + any, expr, pattern, escape_char, - } => self.visit_like(*negated, expr, pattern, escape_char, true), + } => { + if *any { + polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax") + } + self.visit_like(*negated, expr, pattern, escape_char, true) + }, SQLExpr::Nested(expr) => self.visit_expr(expr), SQLExpr::Position { expr, r#in } => Ok( // note: SQL is 1-indexed @@ -537,6 +552,7 @@ impl SQLExprVisitor<'_> { ) { SQLExpr::Like { negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch), + any: false, expr: Box::new(left.clone()), pattern: Box::new(right.clone()), escape_char: None, @@ -544,6 +560,7 @@ impl SQLExprVisitor<'_> { } else { SQLExpr::ILike { negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch), + any: false, expr: Box::new(left.clone()), pattern: Box::new(right.clone()), escape_char: None, From e6a3d37bdcdcfd3573d1bd0354f4d0e32e03981a Mon Sep 17 00:00:00 2001 From: Luke Manley Date: Tue, 3 Dec 2024 01:56:41 -0500 Subject: [PATCH 03/20] fix: Materialize smallest dyn ints to use feature gate for i8/i16 (#20108) --- crates/polars-core/src/utils/supertype.rs | 52 ++++++++++++----------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs index 4e30a8da7904..fad3299d5891 100644 --- a/crates/polars-core/src/utils/supertype.rs +++ b/crates/polars-core/src/utils/supertype.rs @@ 
-465,35 +465,39 @@ pub fn materialize_dyn_int(v: i128) -> AnyValue<'static> { fn materialize_dyn_int_pos(v: i128) -> AnyValue<'static> { // Try to get the "smallest" fitting value. // TODO! next breaking go to true smallest. - match u8::try_from(v).ok() { - Some(v) => AnyValue::UInt8(v), - None => match u16::try_from(v).ok() { - Some(v) => AnyValue::UInt16(v), - None => match u32::try_from(v).ok() { - Some(v) => AnyValue::UInt32(v), - None => match u64::try_from(v).ok() { - Some(v) => AnyValue::UInt64(v), - None => AnyValue::Null, - }, - }, + #[cfg(feature = "dtype-u8")] + if let Ok(v) = u8::try_from(v) { + return AnyValue::UInt8(v); + } + #[cfg(feature = "dtype-u16")] + if let Ok(v) = u16::try_from(v) { + return AnyValue::UInt16(v); + } + match u32::try_from(v).ok() { + Some(v) => AnyValue::UInt32(v), + None => match u64::try_from(v).ok() { + Some(v) => AnyValue::UInt64(v), + None => AnyValue::Null, }, } } fn materialize_smallest_dyn_int(v: i128) -> AnyValue<'static> { - match i8::try_from(v).ok() { - Some(v) => AnyValue::Int8(v), - None => match i16::try_from(v).ok() { - Some(v) => AnyValue::Int16(v), - None => match i32::try_from(v).ok() { - Some(v) => AnyValue::Int32(v), - None => match i64::try_from(v).ok() { - Some(v) => AnyValue::Int64(v), - None => match u64::try_from(v).ok() { - Some(v) => AnyValue::UInt64(v), - None => AnyValue::Null, - }, - }, + #[cfg(feature = "dtype-i8")] + if let Ok(v) = i8::try_from(v) { + return AnyValue::Int8(v); + } + #[cfg(feature = "dtype-i16")] + if let Ok(v) = i16::try_from(v) { + return AnyValue::Int16(v); + } + match i32::try_from(v).ok() { + Some(v) => AnyValue::Int32(v), + None => match i64::try_from(v).ok() { + Some(v) => AnyValue::Int64(v), + None => match u64::try_from(v).ok() { + Some(v) => AnyValue::UInt64(v), + None => AnyValue::Null, }, }, } From ec52966ddcd478d8e49243a0d796bb34ed35dc26 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Tue, 3 Dec 2024 07:57:24 +0100 
Subject: [PATCH 04/20] feat(python): Enable view arrow export in `write_delta` (#20092) --- py-polars/polars/dataframe/frame.py | 8 ++++++-- py-polars/pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 649b1b15120e..92b10e81ca45 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4283,13 +4283,18 @@ def write_delta( _check_if_delta_available() from deltalake import DeltaTable, write_deltalake + from deltalake import __version__ as delta_version + from packaging.version import Version _check_for_unsupported_types(self.dtypes) if isinstance(target, (str, Path)): target = _resolve_delta_lake_uri(str(target), strict=False) - data = self.to_arrow() + if Version(delta_version) >= Version("0.22.3"): + data = self.to_arrow(compat_level=CompatLevel.newest()) + else: + data = self.to_arrow() if mode == "merge": if delta_merge_options is None: @@ -4316,7 +4321,6 @@ def write_delta( schema=schema, mode=mode, storage_options=storage_options, - large_dtypes=True, **delta_write_options, ) return None diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 63cf77976175..fa060fb48810 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -61,7 +61,7 @@ database = ["polars[adbc,connectorx,sqlalchemy]", "nest-asyncio"] fsspec = ["fsspec"] # Other I/O -deltalake = ["deltalake >= 0.15.0"] +deltalake = ["deltalake >= 0.19.0"] iceberg = ["pyiceberg >= 0.5.0"] # Other From 02638688860514db9730a5dc206af88056060c28 Mon Sep 17 00:00:00 2001 From: Luka Peschke Date: Tue, 3 Dec 2024 07:58:10 +0100 Subject: [PATCH 05/20] feat(rust): Allow setting and reading custom schema-level IPC metadata (#20066) Signed-off-by: Luka Peschke --- crates/polars-arrow/src/io/ipc/append/mod.rs | 1 + crates/polars-arrow/src/io/ipc/read/file.rs | 8 +++-- crates/polars-arrow/src/io/ipc/read/schema.rs | 30 +++++++++++++++-- 
crates/polars-arrow/src/io/ipc/read/stream.rs | 5 ++- .../polars-arrow/src/io/ipc/write/schema.rs | 17 ++++++++-- .../polars-arrow/src/io/ipc/write/stream.rs | 15 ++++++++- .../polars-arrow/src/io/ipc/write/writer.rs | 22 +++++++++++-- crates/polars-io/src/ipc/ipc_file.rs | 12 ++++++- crates/polars-io/src/ipc/ipc_stream.rs | 32 +++++++++++++++++-- crates/polars-io/src/ipc/write.rs | 24 ++++++++++++-- .../polars-parquet/src/arrow/write/schema.rs | 4 +-- 11 files changed, 151 insertions(+), 19 deletions(-) diff --git a/crates/polars-arrow/src/io/ipc/append/mod.rs b/crates/polars-arrow/src/io/ipc/append/mod.rs index 340e8cea2a9b..446243592bc2 100644 --- a/crates/polars-arrow/src/io/ipc/append/mod.rs +++ b/crates/polars-arrow/src/io/ipc/append/mod.rs @@ -66,6 +66,7 @@ impl FileWriter { cannot_replace: true, }, encoded_message: Default::default(), + custom_schema_metadata: None, }) } } diff --git a/crates/polars-arrow/src/io/ipc/read/file.rs b/crates/polars-arrow/src/io/ipc/read/file.rs index e75fae36730e..d06a59b3ff9d 100644 --- a/crates/polars-arrow/src/io/ipc/read/file.rs +++ b/crates/polars-arrow/src/io/ipc/read/file.rs @@ -11,7 +11,7 @@ use super::common::*; use super::schema::fb_to_schema; use super::{Dictionaries, OutOfSpecKind, SendableIterator}; use crate::array::Array; -use crate::datatypes::ArrowSchemaRef; +use crate::datatypes::{ArrowSchemaRef, Metadata}; use crate::io::ipc::IpcSchema; use crate::record_batch::RecordBatchT; @@ -21,6 +21,9 @@ pub struct FileMetadata { /// The schema that is read from the file footer pub schema: ArrowSchemaRef, + /// The custom metadata that is read from the schema + pub custom_schema_metadata: Option>, + /// The files' [`IpcSchema`] pub ipc_schema: IpcSchema, @@ -245,7 +248,7 @@ pub fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult>>()) .transpose()?; let ipc_schema = deserialize_schema_ref_from_footer(footer)?; - let (schema, ipc_schema) = fb_to_schema(ipc_schema)?; + let (schema, ipc_schema, 
custom_schema_metadata) = fb_to_schema(ipc_schema)?; Ok(FileMetadata { schema: Arc::new(schema), @@ -253,6 +256,7 @@ pub fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult PolarsResult<(ArrowSchema, IpcSchema)> { +pub fn deserialize_schema( + message: &[u8], +) -> PolarsResult<(ArrowSchema, IpcSchema, Option)> { let message = arrow_format::ipc::MessageRef::read_as_root(message) .map_err(|_err| polars_err!(oos = "Unable deserialize message: {err:?}"))?; @@ -374,7 +376,7 @@ pub fn deserialize_schema(message: &[u8]) -> PolarsResult<(ArrowSchema, IpcSchem /// Deserialize the raw Schema table from IPC format to Schema data type pub(super) fn fb_to_schema( schema: arrow_format::ipc::SchemaRef, -) -> PolarsResult<(ArrowSchema, IpcSchema)> { +) -> PolarsResult<(ArrowSchema, IpcSchema, Option)> { let fields = schema .fields()? .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?; @@ -393,12 +395,33 @@ pub(super) fn fb_to_schema( arrow_format::ipc::Endianness::Big => false, }; + let custom_schema_metadata = match schema.custom_metadata()? 
{ + None => None, + Some(metadata) => { + let metadata: Metadata = metadata + .into_iter() + .filter_map(|kv_result| { + // FIXME: silently hiding errors here + let kv_ref = kv_result.ok()?; + Some((kv_ref.key().ok()??.into(), kv_ref.value().ok()??.into())) + }) + .collect(); + + if metadata.is_empty() { + None + } else { + Some(metadata) + } + }, + }; + Ok(( arrow_schema, IpcSchema { fields: ipc_fields, is_little_endian, }, + custom_schema_metadata, )) } @@ -415,11 +438,12 @@ pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> PolarsResult, + /// The IPC version of the stream pub version: arrow_format::ipc::MetadataVersion, diff --git a/crates/polars-arrow/src/io/ipc/write/schema.rs b/crates/polars-arrow/src/io/ipc/write/schema.rs index e8ed25c5c77e..76caf6243428 100644 --- a/crates/polars-arrow/src/io/ipc/write/schema.rs +++ b/crates/polars-arrow/src/io/ipc/write/schema.rs @@ -7,8 +7,12 @@ use crate::datatypes::{ use crate::io::ipc::endianness::is_native_little_endian; /// Converts a [ArrowSchema] and [IpcField]s to a flatbuffers-encoded [arrow_format::ipc::Message]. 
-pub fn schema_to_bytes(schema: &ArrowSchema, ipc_fields: &[IpcField]) -> Vec { - let schema = serialize_schema(schema, ipc_fields); +pub fn schema_to_bytes( + schema: &ArrowSchema, + ipc_fields: &[IpcField], + custom_metadata: Option<&Metadata>, +) -> Vec { + let schema = serialize_schema(schema, ipc_fields, custom_metadata); let message = arrow_format::ipc::Message { version: arrow_format::ipc::MetadataVersion::V5, @@ -24,6 +28,7 @@ pub fn schema_to_bytes(schema: &ArrowSchema, ipc_fields: &[IpcField]) -> Vec pub fn serialize_schema( schema: &ArrowSchema, ipc_fields: &[IpcField], + custom_schema_metadata: Option<&Metadata>, ) -> arrow_format::ipc::Schema { let endianness = if is_native_little_endian() { arrow_format::ipc::Endianness::Little @@ -37,7 +42,13 @@ pub fn serialize_schema( .map(|(field, ipc_field)| serialize_field(field, ipc_field)) .collect::>(); - let custom_metadata = None; + let custom_metadata = custom_schema_metadata.and_then(|custom_meta| { + let as_kv = custom_meta + .iter() + .map(|(key, val)| key_value(key.clone().into_string(), val.clone().into_string())) + .collect::>(); + (!as_kv.is_empty()).then_some(as_kv) + }); arrow_format::ipc::Schema { endianness, diff --git a/crates/polars-arrow/src/io/ipc/write/stream.rs b/crates/polars-arrow/src/io/ipc/write/stream.rs index 330b35d4ca4b..b6191e45f902 100644 --- a/crates/polars-arrow/src/io/ipc/write/stream.rs +++ b/crates/polars-arrow/src/io/ipc/write/stream.rs @@ -4,6 +4,7 @@ //! 
however the `FileWriter` expects a reader that supports `Seek`ing use std::io::Write; +use std::sync::Arc; use polars_error::{PolarsError, PolarsResult}; @@ -30,6 +31,8 @@ pub struct StreamWriter { finished: bool, /// Keeps track of dictionaries that have been written dictionary_tracker: DictionaryTracker, + /// Custom schema-level metadata + custom_schema_metadata: Option>, ipc_fields: Option>, } @@ -46,9 +49,15 @@ impl StreamWriter { cannot_replace: false, }, ipc_fields: None, + custom_schema_metadata: None, } } + /// Sets custom schema metadata. Must be called before `start` is called + pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc) { + self.custom_schema_metadata = Some(custom_metadata); + } + /// Starts the stream by writing a Schema message to it. /// Use `ipc_fields` to declare dictionary ids in the schema, for dictionary-reuse pub fn start( @@ -63,7 +72,11 @@ impl StreamWriter { }); let encoded_message = EncodedData { - ipc_message: schema_to_bytes(schema, self.ipc_fields.as_ref().unwrap()), + ipc_message: schema_to_bytes( + schema, + self.ipc_fields.as_ref().unwrap(), + self.custom_schema_metadata.as_deref(), + ), arrow_data: vec![], }; write_message(&mut self.writer, &encoded_message)?; diff --git a/crates/polars-arrow/src/io/ipc/write/writer.rs b/crates/polars-arrow/src/io/ipc/write/writer.rs index d709f5e8a195..b95064cef8b3 100644 --- a/crates/polars-arrow/src/io/ipc/write/writer.rs +++ b/crates/polars-arrow/src/io/ipc/write/writer.rs @@ -1,4 +1,5 @@ use std::io::Write; +use std::sync::Arc; use arrow_format::ipc::planus::Builder; use polars_error::{polars_bail, PolarsResult}; @@ -40,6 +41,8 @@ pub struct FileWriter { pub(crate) dictionary_tracker: DictionaryTracker, /// Buffer/scratch that is reused between writes pub(crate) encoded_message: EncodedData, + /// Custom schema-level metadata + pub(crate) custom_schema_metadata: Option>, } impl FileWriter { @@ -83,6 +86,7 @@ impl FileWriter { cannot_replace: true, }, encoded_message: 
Default::default(), + custom_schema_metadata: None, } } @@ -116,7 +120,12 @@ impl FileWriter { // write the schema, set the written bytes to the schema let encoded_message = EncodedData { - ipc_message: schema_to_bytes(&self.schema, &self.ipc_fields), + ipc_message: schema_to_bytes( + &self.schema, + &self.ipc_fields, + // No need to pass metadata here, as it is already written to the footer in `finish` + None, + ), arrow_data: vec![], }; @@ -210,7 +219,11 @@ impl FileWriter { // write EOS write_continuation(&mut self.writer, 0)?; - let schema = schema::serialize_schema(&self.schema, &self.ipc_fields); + let schema = schema::serialize_schema( + &self.schema, + &self.ipc_fields, + self.custom_schema_metadata.as_deref(), + ); let root = arrow_format::ipc::Footer { version: arrow_format::ipc::MetadataVersion::V5, @@ -230,4 +243,9 @@ impl FileWriter { Ok(()) } + + /// Sets custom schema metadata. Must be called before `start` is called + pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc) { + self.custom_schema_metadata = Some(custom_metadata); + } } diff --git a/crates/polars-io/src/ipc/ipc_file.rs b/crates/polars-io/src/ipc/ipc_file.rs index 100d37b2c941..81925bbe180c 100644 --- a/crates/polars-io/src/ipc/ipc_file.rs +++ b/crates/polars-io/src/ipc/ipc_file.rs @@ -35,7 +35,7 @@ use std::io::{Read, Seek}; use std::path::PathBuf; -use arrow::datatypes::ArrowSchemaRef; +use arrow::datatypes::{ArrowSchemaRef, Metadata}; use arrow::io::ipc::read::{self, get_row_count}; use arrow::record_batch::RecordBatch; use polars_core::prelude::*; @@ -115,6 +115,16 @@ impl IpcReader { self.get_metadata()?; Ok(self.schema.as_ref().unwrap().clone()) } + + /// Get schema-level custom metadata of the Ipc file + pub fn custom_metadata(&mut self) -> PolarsResult>> { + self.get_metadata()?; + Ok(self + .metadata + .as_ref() + .and_then(|meta| meta.custom_schema_metadata.clone())) + } + /// Stop reading when `n` rows are read. 
pub fn with_n_rows(mut self, num_rows: Option) -> Self { self.n_rows = num_rows; diff --git a/crates/polars-io/src/ipc/ipc_stream.rs b/crates/polars-io/src/ipc/ipc_stream.rs index 6393c639cf35..c3d2f353759b 100644 --- a/crates/polars-io/src/ipc/ipc_stream.rs +++ b/crates/polars-io/src/ipc/ipc_stream.rs @@ -36,6 +36,7 @@ use std::io::{Read, Write}; use std::path::PathBuf; +use arrow::datatypes::Metadata; use arrow::io::ipc::read::{StreamMetadata, StreamState}; use arrow::io::ipc::write::WriteOptions; use arrow::io::ipc::{read, write}; @@ -83,6 +84,12 @@ impl IpcStreamReader { pub fn arrow_schema(&mut self) -> PolarsResult { Ok(self.metadata()?.schema) } + + /// Get schema-level custom metadata of the Ipc Stream file + pub fn custom_metadata(&mut self) -> PolarsResult>> { + Ok(self.metadata()?.custom_schema_metadata.map(Arc::new)) + } + /// Stop reading when `n` rows are read. pub fn with_n_rows(mut self, num_rows: Option) -> Self { self.n_rows = num_rows; @@ -198,8 +205,17 @@ where /// fn example(df: &mut DataFrame) -> PolarsResult<()> { /// let mut file = File::create("file.ipc").expect("could not create file"); /// -/// IpcStreamWriter::new(&mut file) -/// .finish(df) +/// let mut writer = IpcStreamWriter::new(&mut file); +/// +/// let custom_metadata = [ +/// ("first_name".into(), "John".into()), +/// ("last_name".into(), "Doe".into()), +/// ] +/// .into_iter() +/// .collect(); +/// writer.set_custom_schema_metadata(Arc::new(custom_metadata)); +/// +/// writer.finish(df) /// } /// /// ``` @@ -208,6 +224,8 @@ pub struct IpcStreamWriter { writer: W, compression: Option, compat_level: CompatLevel, + /// Custom schema-level metadata + custom_schema_metadata: Option>, } use arrow::record_batch::RecordBatch; @@ -225,6 +243,11 @@ impl IpcStreamWriter { self.compat_level = compat_level; self } + + /// Sets custom schema metadata. 
Must be called before `start` is called + pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc) { + self.custom_schema_metadata = Some(custom_metadata); + } } impl SerWriter for IpcStreamWriter @@ -236,6 +259,7 @@ where writer, compression: None, compat_level: CompatLevel::oldest(), + custom_schema_metadata: None, } } @@ -247,6 +271,10 @@ where }, ); + if let Some(custom_metadata) = &self.custom_schema_metadata { + ipc_stream_writer.set_custom_schema_metadata(Arc::clone(custom_metadata)); + } + ipc_stream_writer.start(&df.schema().to_arrow(self.compat_level), None)?; let df = chunk_df_for_writing(df, 512 * 512)?; let iter = df.iter_chunks(self.compat_level, true); diff --git a/crates/polars-io/src/ipc/write.rs b/crates/polars-io/src/ipc/write.rs index f6277bc6c1bd..38b5d1d27fde 100644 --- a/crates/polars-io/src/ipc/write.rs +++ b/crates/polars-io/src/ipc/write.rs @@ -1,5 +1,6 @@ use std::io::Write; +use arrow::datatypes::Metadata; use arrow::io::ipc::write::{self, EncodedData, WriteOptions}; use polars_core::prelude::*; #[cfg(feature = "serde")] @@ -36,8 +37,16 @@ impl IpcWriterOptions { /// fn example(df: &mut DataFrame) -> PolarsResult<()> { /// let mut file = File::create("file.ipc").expect("could not create file"); /// -/// IpcWriter::new(&mut file) -/// .finish(df) +/// let mut writer = IpcWriter::new(&mut file); +/// +/// let custom_metadata = [ +/// ("first_name".into(), "John".into()), +/// ("last_name".into(), "Doe".into()), +/// ] +/// .into_iter() +/// .collect(); +/// writer.set_custom_schema_metadata(Arc::new(custom_metadata)); +/// writer.finish(df) /// } /// /// ``` @@ -48,6 +57,7 @@ pub struct IpcWriter { /// Polars' flavor of arrow. This might be temporary. pub(super) compat_level: CompatLevel, pub(super) parallel: bool, + pub(super) custom_schema_metadata: Option>, } impl IpcWriter { @@ -84,6 +94,11 @@ impl IpcWriter { compat_level: self.compat_level, }) } + + /// Sets custom schema metadata. 
Must be called before `start` is called + pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc) { + self.custom_schema_metadata = Some(custom_metadata); + } } impl SerWriter for IpcWriter @@ -96,6 +111,7 @@ where compression: None, compat_level: CompatLevel::newest(), parallel: true, + custom_schema_metadata: None, } } @@ -109,6 +125,10 @@ where compression: self.compression.map(|c| c.into()), }, )?; + if let Some(custom_metadata) = &self.custom_schema_metadata { + ipc_writer.set_custom_schema_metadata(Arc::clone(custom_metadata)); + } + if self.parallel { df.align_chunks_par(); } else { diff --git a/crates/polars-parquet/src/arrow/write/schema.rs b/crates/polars-parquet/src/arrow/write/schema.rs index 61c7f9fad218..3aa7c62aac6a 100644 --- a/crates/polars-parquet/src/arrow/write/schema.rs +++ b/crates/polars-parquet/src/arrow/write/schema.rs @@ -55,9 +55,9 @@ pub fn schema_to_metadata_key(schema: &ArrowSchema) -> KeyValue { .map(|field| convert_field(field.clone())) .map(|x| (x.name.clone(), x)) .collect(); - schema_to_bytes(&schema, &default_ipc_fields(schema.iter_values())) + schema_to_bytes(&schema, &default_ipc_fields(schema.iter_values()), None) } else { - schema_to_bytes(schema, &default_ipc_fields(schema.iter_values())) + schema_to_bytes(schema, &default_ipc_fields(schema.iter_values()), None) }; // manually prepending the length to the schema as arrow uses the legacy IPC format From 353bc0efbf387fa671049401d59c6d860f667b64 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 3 Dec 2024 08:02:11 +0000 Subject: [PATCH 06/20] docs: Update `by` param description for rolling_*_by functions (#19715) --- py-polars/polars/expr/expr.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index c993825bd223..ee75567be6ec 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -6201,7 +6201,9 @@ def 
rolling_min_by( Parameters ---------- by - This column must be of dtype Datetime or Date. + Should be ``DateTime``, ``Date``, ``UInt64``, ``UInt32``, ``Int64``, + or ``Int32`` data type (note that the integral ones require using `'i'` + in `window size`). window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -6323,7 +6325,9 @@ def rolling_max_by( Parameters ---------- by - This column must be of dtype Datetime or Date. + Should be ``DateTime``, ``Date``, ``UInt64``, ``UInt32``, ``Int64``, + or ``Int32`` data type (note that the integral ones require using `'i'` + in `window size`). window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -6471,7 +6475,9 @@ def rolling_mean_by( Parameters ---------- by - This column must be of dtype Datetime or Date. + Should be ``DateTime``, ``Date``, ``UInt64``, ``UInt32``, ``Int64``, + or ``Int32`` data type (note that the integral ones require using `'i'` + in `window size`). window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -6650,7 +6656,9 @@ def rolling_sum_by( The number of values in the window that should be non-null before computing a result. by - This column must of dtype `{Date, Datetime}` + Should be ``DateTime``, ``Date``, ``UInt64``, ``UInt32``, ``Int64``, + or ``Int32`` data type (note that the integral ones require using `'i'` + in `window size`). closed : {'left', 'right', 'both', 'none'} Define which sides of the temporal interval are closed (inclusive), defaults to `'right'`. @@ -6775,7 +6783,9 @@ def rolling_std_by( Parameters ---------- by - This column must be of dtype Datetime or Date. + Should be ``DateTime``, ``Date``, ``UInt64``, ``UInt32``, ``Int64``, + or ``Int32`` data type (note that the integral ones require using `'i'` + in `window size`). 
window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -6932,7 +6942,9 @@ def rolling_var_by( Parameters ---------- by - This column must be of dtype Datetime or Date. + Should be ``DateTime``, ``Date``, ``UInt64``, ``UInt32``, ``Int64``, + or ``Int32`` data type (note that the integral ones require using `'i'` + in `window size`). window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -7088,7 +7100,9 @@ def rolling_median_by( Parameters ---------- by - This column must be of dtype Datetime or Date. + Should be ``DateTime``, ``Date``, ``UInt64``, ``UInt32``, ``Int64``, + or ``Int32`` data type (note that the integral ones require using `'i'` + in `window size`). window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -7214,7 +7228,9 @@ def rolling_quantile_by( Parameters ---------- by - This column must be of dtype Datetime or Date. + Should be ``DateTime``, ``Date``, ``UInt64``, ``UInt32``, ``Int64``, + or ``Int32`` data type (note that the integral ones require using `'i'` + in `window size`). quantile Quantile between 0.0 and 1.0. 
interpolation : {'nearest', 'higher', 'lower', 'midpoint', 'linear'} From 62542038debd6c178b377f3397b15340e263f919 Mon Sep 17 00:00:00 2001 From: Dzenan Jupic <56133904+DzenanJupic@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:20:31 +0100 Subject: [PATCH 07/20] fix: subtraction with underflow on empty FixedSizeBinaryArray (#20109) --- crates/polars-compute/src/cast/binary_to.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/crates/polars-compute/src/cast/binary_to.rs b/crates/polars-compute/src/cast/binary_to.rs index f5af77db9e88..071bdd85456c 100644 --- a/crates/polars-compute/src/cast/binary_to.rs +++ b/crates/polars-compute/src/cast/binary_to.rs @@ -202,12 +202,15 @@ pub fn fixed_size_binary_to_binview(from: &FixedSizeBinaryArray) -> BinaryViewAr // This is zero-copy for the buffer since split just increases the data since let mut buffer = from.values().clone(); let mut buffers = Vec::with_capacity(num_buffers); - for _ in 0..num_buffers - 1 { - let slice; - (slice, buffer) = buffer.split_at(split_point); - buffers.push(slice); + + if let Some(num_buffers) = num_buffers.checked_sub(1) { + for _ in 0..num_buffers { + let slice; + (slice, buffer) = buffer.split_at(split_point); + buffers.push(slice); + } + buffers.push(buffer); } - buffers.push(buffer); let mut iter = from.values_iter(); let iter = iter.by_ref(); From 7baabb225dc40660ede14b803cfe847af4c5983f Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Tue, 3 Dec 2024 10:55:25 +0100 Subject: [PATCH 08/20] fix: DataFrame `.get_column` after `drop_in_place` (#20120) --- crates/polars-core/src/frame/mod.rs | 6 +++++- py-polars/tests/unit/dataframe/test_df.py | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index b524bbd06d0f..2dc78741f977 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1624,7 +1624,11 @@ impl DataFrame { pub fn 
get_column_index(&self, name: &str) -> Option { let schema = self.cached_schema.get_or_init(|| Arc::new(self.schema())); if let Some(idx) = schema.index_of(name) { - if self.get_columns()[idx].name() == name { + if self + .get_columns() + .get(idx) + .is_some_and(|c| c.name() == name) + { return Some(idx); } } diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 2953d018430f..b00ad6032f6b 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -3006,3 +3006,10 @@ def test_dataframe_creation_with_different_series_lengths_19795() -> None: match='could not create a new DataFrame: series "a" has length 2 while series "b" has length 1', ): pl.DataFrame({"a": [1, 2], "b": [1]}) + + +def test_get_column_after_drop_20119() -> None: + df = pl.DataFrame({"a": ["A"], "b": ["B"], "c": ["C"]}) + df.drop_in_place("a") + c = df.get_column("c") + assert_series_equal(c, pl.Series("c", ["C"])) From d86e44b9796d57b3740cbd7e0b01901958085e2f Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Tue, 3 Dec 2024 11:25:17 +0100 Subject: [PATCH 09/20] chore: Add a bunch more automated row encoding sortedness tests (#20056) --- crates/polars-row/src/decode.rs | 32 +- crates/polars-row/src/encode.rs | 104 +++-- py-polars/pyproject.toml | 1 + py-polars/tests/unit/test_row_encoding.py | 259 +----------- .../tests/unit/test_row_encoding_sort.py | 367 ++++++++++++++++++ 5 files changed, 436 insertions(+), 327 deletions(-) create mode 100644 py-polars/tests/unit/test_row_encoding_sort.py diff --git a/crates/polars-row/src/decode.rs b/crates/polars-row/src/decode.rs index adbf7ce12dc4..2accc53e6b93 100644 --- a/crates/polars-row/src/decode.rs +++ b/crates/polars-row/src/decode.rs @@ -4,13 +4,11 @@ use arrow::datatypes::ArrowDataType; use arrow::offset::OffsetsBuffer; use self::encode::fixed_size; -use self::fixed::decimal; use self::row::RowEncodingOptions; use self::variable::utf8::decode_str; use 
super::*; -use crate::fixed::boolean::decode_bool; -use crate::fixed::numeric::decode_primitive; -use crate::variable::binary::decode_binview; +use crate::fixed::{boolean, decimal, numeric, packed_u32}; +use crate::variable::{binary, no_order, utf8}; /// Decode `rows` into a arrow format /// # Safety @@ -93,13 +91,11 @@ fn dtype_and_data_to_encoded_item_len( match dtype { D::Binary | D::LargeBinary | D::BinaryView | D::Utf8 | D::LargeUtf8 | D::Utf8View if opt.contains(RowEncodingOptions::NO_ORDER) => - unsafe { crate::variable::no_order::len_from_buffer(data, opt) }, + unsafe { no_order::len_from_buffer(data, opt) }, D::Binary | D::LargeBinary | D::BinaryView => unsafe { - crate::variable::binary::encoded_item_len(data, opt) - }, - D::Utf8 | D::LargeUtf8 | D::Utf8View => unsafe { - crate::variable::utf8::len_from_buffer(data, opt) + binary::encoded_item_len(data, opt) }, + D::Utf8 | D::LargeUtf8 | D::Utf8View => unsafe { utf8::len_from_buffer(data, opt) }, D::List(list_field) | D::LargeList(list_field) => { let mut data = data; @@ -146,8 +142,8 @@ fn dtype_and_data_to_encoded_item_len( }; let num_bits = values.len().next_power_of_two().trailing_zeros() as usize + 1; - let str_len = unsafe { crate::variable::utf8::len_from_buffer(data, opt) }; - str_len + crate::fixed::packed_u32::len_from_num_bits(num_bits) + let str_len = unsafe { utf8::len_from_buffer(data, opt) }; + str_len + packed_u32::len_from_num_bits(num_bits) }, D::Union(_, _, _) => todo!(), @@ -205,8 +201,8 @@ unsafe fn decode_lexical_cat( let num_bits = values.len().next_power_of_two().trailing_zeros() as usize + 1; - let mut s = crate::fixed::packed_u32::decode(rows, opt, num_bits); - crate::fixed::packed_u32::decode(rows, opt, num_bits).with_validity(s.take_validity()) + let mut s = packed_u32::decode(rows, opt, num_bits); + packed_u32::decode(rows, opt, num_bits).with_validity(s.take_validity()) } unsafe fn decode( @@ -218,11 +214,11 @@ unsafe fn decode( use ArrowDataType as D; match dtype { D::Null 
=> NullArray::new(D::Null, rows.len()).to_boxed(), - D::Boolean => decode_bool(rows, opt).to_boxed(), + D::Boolean => boolean::decode_bool(rows, opt).to_boxed(), D::Binary | D::LargeBinary | D::BinaryView | D::Utf8 | D::LargeUtf8 | D::Utf8View if opt.contains(RowEncodingOptions::NO_ORDER) => { - let array = crate::variable::no_order::decode_variable_no_order(rows, opt); + let array = no_order::decode_variable_no_order(rows, opt); if matches!(dtype, D::Utf8 | D::LargeUtf8 | D::Utf8View) { unsafe { array.to_utf8view_unchecked() }.to_boxed() @@ -230,7 +226,7 @@ unsafe fn decode( array.to_boxed() } }, - D::Binary | D::LargeBinary | D::BinaryView => decode_binview(rows, opt).to_boxed(), + D::Binary | D::LargeBinary | D::BinaryView => binary::decode_binview(rows, opt).to_boxed(), D::Utf8 | D::LargeUtf8 | D::Utf8View => decode_str(rows, opt).boxed(), D::Struct(fields) => { @@ -334,7 +330,7 @@ unsafe fn decode( if let Some(dict) = dict { return match dict { RowEncodingCatOrder::Physical(num_bits) => { - crate::fixed::packed_u32::decode(rows, opt, *num_bits).to_boxed() + packed_u32::decode(rows, opt, *num_bits).to_boxed() }, RowEncodingCatOrder::Lexical(values) => { decode_lexical_cat(rows, opt, values).to_boxed() @@ -345,7 +341,7 @@ unsafe fn decode( } with_match_arrow_primitive_type!(dt, |$T| { - decode_primitive::<$T>(rows, opt).to_boxed() + numeric::decode_primitive::<$T>(rows, opt).to_boxed() }) }, } diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index b5c0b00f2db5..702d2228d2f1 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -8,9 +8,9 @@ use arrow::bitmap::Bitmap; use arrow::datatypes::ArrowDataType; use arrow::types::Offset; -use crate::fixed::decimal; -use crate::fixed::numeric::FixedLengthEncoding; +use crate::fixed::{boolean, decimal, numeric, packed_u32}; use crate::row::{RowEncodingOptions, RowsEncoded}; +use crate::variable::{binary, no_order, utf8}; use crate::widths::RowWidths; use 
crate::{with_match_arrow_primitive_type, ArrayRef, RowEncodingCatOrder}; @@ -163,24 +163,21 @@ fn biniter_num_column_bytes( ) -> Encoder { if opt.contains(RowEncodingOptions::NO_ORDER) { match validity { - None => row_widths - .push_iter(iter.map(|v| crate::variable::no_order::len_from_item(Some(v), opt))), - Some(validity) => { - row_widths.push_iter(iter.zip(validity.iter()).map(|(v, is_valid)| { - crate::variable::no_order::len_from_item(is_valid.then_some(v), opt) - })) - }, + None => row_widths.push_iter(iter.map(|v| no_order::len_from_item(Some(v), opt))), + Some(validity) => row_widths.push_iter( + iter.zip(validity.iter()) + .map(|(v, is_valid)| no_order::len_from_item(is_valid.then_some(v), opt)), + ), } } else { match validity { None => row_widths.push_iter( iter.map(|v| crate::variable::binary::encoded_len_from_len(Some(v), opt)), ), - Some(validity) => { - row_widths.push_iter(iter.zip(validity.iter()).map(|(v, is_valid)| { - crate::variable::binary::encoded_len_from_len(is_valid.then_some(v), opt) - })) - }, + Some(validity) => row_widths.push_iter( + iter.zip(validity.iter()) + .map(|(v, is_valid)| binary::encoded_len_from_len(is_valid.then_some(v), opt)), + ), } }; @@ -199,23 +196,20 @@ fn striter_num_column_bytes( ) -> Encoder { if opt.contains(RowEncodingOptions::NO_ORDER) { match validity { - None => row_widths - .push_iter(iter.map(|v| crate::variable::no_order::len_from_item(Some(v), opt))), - Some(validity) => { - row_widths.push_iter(iter.zip(validity.iter()).map(|(v, is_valid)| { - crate::variable::no_order::len_from_item(is_valid.then_some(v), opt) - })) - }, + None => row_widths.push_iter(iter.map(|v| no_order::len_from_item(Some(v), opt))), + Some(validity) => row_widths.push_iter( + iter.zip(validity.iter()) + .map(|(v, is_valid)| no_order::len_from_item(is_valid.then_some(v), opt)), + ), } } else { match validity { None => row_widths .push_iter(iter.map(|v| crate::variable::utf8::len_from_item(Some(v), opt))), - Some(validity) => { - 
row_widths.push_iter(iter.zip(validity.iter()).map(|(v, is_valid)| { - crate::variable::utf8::len_from_item(is_valid.then_some(v), opt) - })) - }, + Some(validity) => row_widths.push_iter( + iter.zip(validity.iter()) + .map(|(v, is_valid)| utf8::len_from_item(is_valid.then_some(v), opt)), + ), } }; @@ -241,7 +235,7 @@ fn lexical_cat_num_column_bytes( } let num_bits = values.len().next_power_of_two().trailing_zeros() as usize + 1; - let idx_width = crate::fixed::packed_u32::len_from_num_bits(num_bits); + let idx_width = packed_u32::len_from_num_bits(num_bits); let values: Vec<&str> = values.values_iter().collect(); let mut sort_idxs = (0..values.len() as u32).collect::>(); @@ -516,14 +510,14 @@ unsafe fn encode_strs<'a>( offsets: &mut [usize], ) { if opt.contains(RowEncodingOptions::NO_ORDER) { - crate::variable::no_order::encode_variable_no_order( + no_order::encode_variable_no_order( buffer, iter.map(|v| v.map(str::as_bytes)), opt, offsets, ); } else { - crate::variable::utf8::encode_str(buffer, iter, opt, offsets); + utf8::encode_str(buffer, iter, opt, offsets); } } @@ -534,9 +528,9 @@ unsafe fn encode_bins<'a>( offsets: &mut [usize], ) { if opt.contains(RowEncodingOptions::NO_ORDER) { - crate::variable::no_order::encode_variable_no_order(buffer, iter, opt, offsets); + no_order::encode_variable_no_order(buffer, iter, opt, offsets); } else { - crate::variable::binary::encode_iter(buffer, iter, opt, offsets); + binary::encode_iter(buffer, iter, opt, offsets); } } @@ -553,7 +547,7 @@ unsafe fn encode_flat_array( D::Null => {}, D::Boolean => { let array = array.as_any().downcast_ref::().unwrap(); - crate::fixed::boolean::encode_bool(buffer, array.iter(), opt, offsets); + boolean::encode_bool(buffer, array.iter(), opt, offsets); }, // Needs to happen before numeric arm. 
@@ -578,7 +572,7 @@ unsafe fn encode_flat_array( match dict { RowEncodingCatOrder::Physical(num_bits) => { - crate::fixed::packed_u32::encode(buffer, keys, opt, offsets, *num_bits) + packed_u32::encode(buffer, keys, opt, offsets, *num_bits) }, _ => unreachable!(), } @@ -588,7 +582,7 @@ unsafe fn encode_flat_array( with_match_arrow_primitive_type!(dt, |$T| { let array = array.as_any().downcast_ref::>().unwrap(); - crate::fixed::numeric::encode(buffer, array, opt, offsets); + numeric::encode(buffer, array, opt, offsets); }) }, @@ -798,14 +792,14 @@ unsafe fn encode_array( .unwrap(); let num_bits = sort_idxs.len().next_power_of_two().trailing_zeros() as usize + 1; - crate::fixed::packed_u32::encode_iter( + packed_u32::encode_iter( buffer, keys.iter().map(|k| k.map(|&k| sort_idxs[k as usize])), opt, offsets, num_bits, ); - crate::fixed::packed_u32::encode_slice(buffer, keys.values(), opt, offsets, num_bits); + packed_u32::encode_slice(buffer, keys.values(), opt, offsets, num_bits); }, } } @@ -839,28 +833,33 @@ unsafe fn encode_validity( } pub fn fixed_size(dtype: &ArrowDataType, dict: Option<&RowEncodingCatOrder>) -> Option { - use ArrowDataType::*; + use numeric::FixedLengthEncoding; + use ArrowDataType as D; Some(match dtype { - UInt8 => u8::ENCODED_LEN, - UInt16 => u16::ENCODED_LEN, - UInt32 => match dict { + D::Null => 0, + D::Boolean => 1, + + D::UInt8 => u8::ENCODED_LEN, + D::UInt16 => u16::ENCODED_LEN, + D::UInt32 => match dict { None => u32::ENCODED_LEN, Some(RowEncodingCatOrder::Physical(num_bits)) => { - crate::fixed::packed_u32::len_from_num_bits(*num_bits) + packed_u32::len_from_num_bits(*num_bits) }, _ => return None, }, - UInt64 => u64::ENCODED_LEN, - Int8 => i8::ENCODED_LEN, - Int16 => i16::ENCODED_LEN, - Int32 => i32::ENCODED_LEN, - Int64 => i64::ENCODED_LEN, - Decimal(precision, _) => decimal::len_from_precision(*precision), - Float32 => f32::ENCODED_LEN, - Float64 => f64::ENCODED_LEN, - Boolean => 1, - FixedSizeList(f, width) => 1 + width * 
fixed_size(f.dtype(), dict)?, - Struct(fs) => match dict { + D::UInt64 => u64::ENCODED_LEN, + + D::Int8 => i8::ENCODED_LEN, + D::Int16 => i16::ENCODED_LEN, + D::Int32 => i32::ENCODED_LEN, + D::Int64 => i64::ENCODED_LEN, + + D::Decimal(precision, _) => decimal::len_from_precision(*precision), + D::Float32 => f32::ENCODED_LEN, + D::Float64 => f64::ENCODED_LEN, + D::FixedSizeList(f, width) => 1 + width * fixed_size(f.dtype(), dict)?, + D::Struct(fs) => match dict { None => { let mut sum = 0; for f in fs { @@ -877,7 +876,6 @@ pub fn fixed_size(dtype: &ArrowDataType, dict: Option<&RowEncodingCatOrder>) -> }, _ => unreachable!(), }, - Null => 0, _ => return None, }) } diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index fa060fb48810..4b292b86629d 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -192,6 +192,7 @@ ignore = [ "RUF005", # Consider expression instead of concatenation "SIM102", # Use a single `if` statement instead of nested `if` statements "SIM108", # Use ternary operator + "SIM114", # Combine `if` branches "TD002", # Missing author in TODO "TD003", # Missing issue link on the line following this TODO "TRY003", # Avoid specifying long messages outside the exception class diff --git a/py-polars/tests/unit/test_row_encoding.py b/py-polars/tests/unit/test_row_encoding.py index 1a31d280c77b..7fdf2c1a9a80 100644 --- a/py-polars/tests/unit/test_row_encoding.py +++ b/py-polars/tests/unit/test_row_encoding.py @@ -1,14 +1,14 @@ from __future__ import annotations from decimal import Decimal as D -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING import pytest from hypothesis import given import polars as pl -from polars.testing import assert_frame_equal, assert_series_equal -from polars.testing.parametric import column, dataframes +from polars.testing import assert_frame_equal +from polars.testing.parametric import dataframes if TYPE_CHECKING: from polars._typing import PolarsDataType @@ -326,256 +326,3 
@@ def test_int_after_null() -> None: ), [(False, True, False), (False, True, False)], ) - - -def assert_order_dataframe( - lhs: pl.DataFrame, - rhs: pl.DataFrame, - order: list[Literal["lt", "eq", "gt"]], - *, - descending: bool = False, - nulls_last: bool = False, -) -> None: - field = (descending, nulls_last, False) - l_re = lhs._row_encode([field] * lhs.width).cast(pl.Binary) - r_re = rhs._row_encode([field] * rhs.width).cast(pl.Binary) - - l_lt_r_s = "gt" if descending else "lt" - l_gt_r_s = "lt" if descending else "gt" - - assert_series_equal( - l_re < r_re, pl.Series([o == l_lt_r_s for o in order]), check_names=False - ) - assert_series_equal( - l_re == r_re, pl.Series([o == "eq" for o in order]), check_names=False - ) - assert_series_equal( - l_re > r_re, pl.Series([o == l_gt_r_s for o in order]), check_names=False - ) - - -def assert_order_series( - lhs: pl.series.series.ArrayLike, - rhs: pl.series.series.ArrayLike, - dtype: pl._typing.PolarsDataType, - order: list[Literal["lt", "eq", "gt"]], - *, - descending: bool = False, - nulls_last: bool = False, -) -> None: - lhs = pl.Series("lhs", lhs, dtype).to_frame() - rhs = pl.Series("rhs", rhs, dtype).to_frame() - assert_order_dataframe( - lhs, rhs, order, descending=descending, nulls_last=nulls_last - ) - - -def parametric_order_base(df: pl.DataFrame) -> None: - lhs = df.get_columns()[0] - rhs = df.get_columns()[1] - - field = (False, False, False) - lhs_re = lhs.to_frame()._row_encode([field]).cast(pl.Binary) - rhs_re = rhs.to_frame()._row_encode([field]).cast(pl.Binary) - - assert_series_equal(lhs < rhs, lhs_re < rhs_re, check_names=False) - assert_series_equal(lhs == rhs, lhs_re == rhs_re, check_names=False) - assert_series_equal(lhs > rhs, lhs_re > rhs_re, check_names=False) - - field = (True, False, False) - lhs_re = lhs.to_frame()._row_encode([field]).cast(pl.Binary) - rhs_re = rhs.to_frame()._row_encode([field]).cast(pl.Binary) - - assert_series_equal(lhs > rhs, lhs_re < rhs_re, check_names=False) - 
assert_series_equal(lhs == rhs, lhs_re == rhs_re, check_names=False) - assert_series_equal(lhs < rhs, lhs_re > rhs_re, check_names=False) - - -@given( - df=dataframes([column(dtype=pl.Int32), column(dtype=pl.Int32)], allow_null=False) -) -def test_parametric_int_order(df: pl.DataFrame) -> None: - parametric_order_base(df) - - -@given( - df=dataframes([column(dtype=pl.UInt32), column(dtype=pl.UInt32)], allow_null=False) -) -def test_parametric_uint_order(df: pl.DataFrame) -> None: - parametric_order_base(df) - - -@given( - df=dataframes([column(dtype=pl.String), column(dtype=pl.String)], allow_null=False) -) -def test_parametric_string_order(df: pl.DataFrame) -> None: - parametric_order_base(df) - - -@given( - df=dataframes([column(dtype=pl.Binary), column(dtype=pl.Binary)], allow_null=False) -) -def test_parametric_binary_order(df: pl.DataFrame) -> None: - parametric_order_base(df) - - -def test_order_bool() -> None: - dtype = pl.Boolean - assert_order_series( - [None, False, True], [True, False, None], dtype, ["lt", "eq", "gt"] - ) - assert_order_series( - [None, False, True], - [True, False, None], - dtype, - ["gt", "eq", "lt"], - nulls_last=True, - ) - - assert_order_series( - [False, False, True, True], - [True, False, True, False], - dtype, - ["lt", "eq", "eq", "gt"], - ) - assert_order_series( - [False, False, True, True], - [True, False, True, False], - dtype, - ["lt", "eq", "eq", "gt"], - descending=True, - ) - - -def test_order_int() -> None: - dtype = pl.Int32 - assert_order_series([1, 2, 3], [3, 2, 1], dtype, ["lt", "eq", "gt"]) - assert_order_series([-1, 0, 1], [1, 0, -1], dtype, ["lt", "eq", "gt"]) - assert_order_series([None], [None], dtype, ["eq"]) - assert_order_series([None], [1], dtype, ["lt"]) - assert_order_series([None], [1], dtype, ["gt"], nulls_last=True) - - -def test_order_uint() -> None: - dtype = pl.UInt32 - assert_order_series([1, 2, 3], [3, 2, 1], dtype, ["lt", "eq", "gt"]) - assert_order_series([None], [None], dtype, ["eq"]) - 
assert_order_series([None], [1], dtype, ["lt"]) - assert_order_series([None], [1], dtype, ["gt"], nulls_last=True) - - -def test_order_str() -> None: - dtype = pl.String - assert_order_series(["a", "b", "c"], ["c", "b", "a"], dtype, ["lt", "eq", "gt"]) - assert_order_series( - ["a", "b", "c"], ["c", "b", "a"], dtype, ["lt", "eq", "gt"], descending=True - ) - assert_order_series( - ["a", "aa", "aaa"], ["aaa", "aa", "a"], dtype, ["lt", "eq", "gt"] - ) - assert_order_series( - ["a", "aa", "aaa"], - ["aaa", "aa", "a"], - dtype, - ["lt", "eq", "gt"], - descending=True, - ) - assert_order_series(["", "a", "aa"], ["aa", "a", ""], dtype, ["lt", "eq", "gt"]) - assert_order_series( - ["", "a", "aa"], ["aa", "a", ""], dtype, ["lt", "eq", "gt"], descending=True - ) - assert_order_series([None], [None], dtype, ["eq"]) - assert_order_series([None], ["a"], dtype, ["lt"]) - assert_order_series([None], ["a"], dtype, ["gt"], nulls_last=True) - - -def test_order_bin() -> None: - dtype = pl.Binary - assert_order_series( - [b"a", b"b", b"c"], [b"c", b"b", b"a"], dtype, ["lt", "eq", "gt"] - ) - assert_order_series( - [b"a", b"b", b"c"], - [b"c", b"b", b"a"], - dtype, - ["lt", "eq", "gt"], - descending=True, - ) - assert_order_series( - [b"a", b"aa", b"aaa"], [b"aaa", b"aa", b"a"], dtype, ["lt", "eq", "gt"] - ) - assert_order_series( - [b"a", b"aa", b"aaa"], - [b"aaa", b"aa", b"a"], - dtype, - ["lt", "eq", "gt"], - descending=True, - ) - assert_order_series( - [b"", b"a", b"aa"], [b"aa", b"a", b""], dtype, ["lt", "eq", "gt"] - ) - assert_order_series( - [b"", b"a", b"aa"], - [b"aa", b"a", b""], - dtype, - ["lt", "eq", "gt"], - descending=True, - ) - assert_order_series([None], [None], dtype, ["eq"]) - assert_order_series([None], [b"a"], dtype, ["lt"]) - assert_order_series([None], [b"a"], dtype, ["gt"], nulls_last=True) - - -def test_order_list() -> None: - dtype = pl.List(pl.Int32) - assert_order_series([[1, 2, 3]], [[3, 2, 1]], dtype, ["lt"]) - assert_order_series([[-1, 0, 1]], [[1, 0, 
-1]], dtype, ["lt"]) - assert_order_series([None], [None], dtype, ["eq"]) - assert_order_series([None], [[1, 2, 3]], dtype, ["lt"]) - assert_order_series([None], [[1, 2, 3]], dtype, ["gt"], nulls_last=True) - assert_order_series([[None, 2, 3]], [[None, 2, 1]], dtype, ["gt"]) - - assert_order_series([[]], [[None]], dtype, ["lt"]) - assert_order_series([[]], [[None]], dtype, ["lt"], descending=True) - assert_order_series([[]], [[1]], dtype, ["lt"]) - assert_order_series([[]], [[1]], dtype, ["lt"], descending=True) - assert_order_series([[1]], [[1, 2]], dtype, ["lt"]) - assert_order_series([[1]], [[1, 2]], dtype, ["lt"], descending=True) - - -def test_order_array() -> None: - dtype = pl.Array(pl.Int32, 3) - assert_order_series([[1, 2, 3]], [[3, 2, 1]], dtype, ["lt"]) - assert_order_series([[-1, 0, 1]], [[1, 0, -1]], dtype, ["lt"]) - assert_order_series([None], [None], dtype, ["eq"]) - assert_order_series([None], [[1, 2, 3]], dtype, ["lt"]) - assert_order_series([None], [[1, 2, 3]], dtype, ["gt"], nulls_last=True) - assert_order_series([[None, 2, 3]], [[None, 2, 1]], dtype, ["gt"]) - - -def test_order_masked_array() -> None: - dtype = pl.Array(pl.Int32, 3) - lhs = pl.Series("l", [1, 2, 3], pl.Int32).replace(1, None).reshape((1, 3)) - rhs = pl.Series("r", [3, 2, 1], pl.Int32).replace(3, None).reshape((1, 3)) - assert_order_series(lhs, rhs, dtype, ["gt"]) - - -def test_order_masked_struct() -> None: - dtype = pl.Array(pl.Int32, 3) - lhs = pl.Series("l", [1, 2, 3], pl.Int32).replace(1, None).reshape((1, 3)) - rhs = pl.Series("r", [3, 2, 1], pl.Int32).replace(3, None).reshape((1, 3)) - assert_order_series( - lhs.to_frame().to_struct(), rhs.to_frame().to_struct(), dtype, ["gt"] - ) - - -def test_order_enum() -> None: - dtype = pl.Enum(["a", "x", "0"]) - - assert_order_series(["a", "x", "0"], ["0", "x", "a"], dtype, ["lt", "eq", "gt"]) - assert_order_series( - ["a", "x", "0"], ["0", "x", "a"], dtype, ["lt", "eq", "gt"], descending=True - ) - assert_order_series([None], 
[None], dtype, ["eq"]) - assert_order_series([None], ["a"], dtype, ["lt"]) - assert_order_series([None], ["a"], dtype, ["gt"], nulls_last=True) diff --git a/py-polars/tests/unit/test_row_encoding_sort.py b/py-polars/tests/unit/test_row_encoding_sort.py new file mode 100644 index 000000000000..7e85ccd6ea7c --- /dev/null +++ b/py-polars/tests/unit/test_row_encoding_sort.py @@ -0,0 +1,367 @@ +# mypy: disable-error-code="valid-type" + +from __future__ import annotations + +import datetime +import decimal +import functools +from typing import Any, Literal, Optional, Union + +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import column, dataframes, series + +Element = Optional[ + Union[ + bool, + int, + float, + str, + decimal.Decimal, + datetime.date, + datetime.datetime, + datetime.time, + datetime.timedelta, + list[Any], + dict[Any, Any], + ] +] +OrderSign = Literal[-1, 0, 1] + + +def elem_order_sign( + lhs: Element, rhs: Element, *, descending: bool, nulls_last: bool +) -> OrderSign: + if isinstance(lhs, pl.Series) and isinstance(rhs, pl.Series): + if lhs.equals(rhs): + return 0 + + lhs = list(lhs) + rhs = list(rhs) + + if lhs == rhs: + return 0 + elif lhs is None: + return 1 if nulls_last else -1 + elif rhs is None: + return -1 if nulls_last else 1 + elif isinstance(lhs, bool) and isinstance(rhs, bool): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, datetime.date) and isinstance(rhs, datetime.date): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, datetime.datetime) and isinstance(rhs, datetime.datetime): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, datetime.time) and isinstance(rhs, datetime.time): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, datetime.timedelta) and isinstance(rhs, datetime.timedelta): + return -1 if (lhs < rhs) ^ descending else 1 + elif 
isinstance(lhs, decimal.Decimal) and isinstance(rhs, decimal.Decimal): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, int) and isinstance(rhs, int): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, float) and isinstance(rhs, float): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, bytes) and isinstance(rhs, bytes): + lhs_b: bytes = lhs + rhs_b: bytes = rhs + + for lh, rh in zip(lhs_b, rhs_b): + o = elem_order_sign(lh, rh, descending=descending, nulls_last=nulls_last) + if o != 0: + return o + + if len(lhs_b) == len(rhs_b): + return 0 + else: + return -1 if (len(lhs_b) < len(rhs_b)) ^ descending else 1 + elif isinstance(lhs, str) and isinstance(rhs, str): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, list) and isinstance(rhs, list): + for lh, rh in zip(lhs, rhs): + o = elem_order_sign(lh, rh, descending=descending, nulls_last=nulls_last) + if o != 0: + return o + + if len(lhs) == len(rhs): + return 0 + else: + return -1 if (len(lhs) < len(rhs)) ^ descending else 1 + elif isinstance(lhs, dict) and isinstance(rhs, dict): + assert len(lhs) == len(rhs) + + for lh, rh in zip(lhs.values(), rhs.values()): + o = elem_order_sign(lh, rh, descending=descending, nulls_last=nulls_last) + if o != 0: + return o + + return 0 + else: + pytest.fail("type mismatch") + + +def tuple_order( + lhs: tuple[Element, ...], + rhs: tuple[Element, ...], + *, + descending: list[bool], + nulls_last: list[bool], +) -> OrderSign: + assert len(lhs) == len(rhs) + + for lh, rh, dsc, nl in zip(lhs, rhs, descending, nulls_last): + o = elem_order_sign(lh, rh, descending=dsc, nulls_last=nl) + if o != 0: + return o + + return 0 + + +@given( + s=series( + excluded_dtypes=[pl.Categorical], + max_size=5, + ) +) +def test_series_sort_parametric(s: pl.Series) -> None: + for descending in [False, True]: + for nulls_last in [False, True]: + fields = [(descending, nulls_last, False)] + + def cmp( + lhs: Element, + rhs: Element, 
+ descending: bool = descending, + nulls_last: bool = nulls_last, + ) -> OrderSign: + return elem_order_sign( + lhs, rhs, descending=descending, nulls_last=nulls_last + ) + + rows = list(s) + rows.sort(key=functools.cmp_to_key(cmp)) # type: ignore[arg-type, unused-ignore] + + re = s.to_frame()._row_encode(fields) + re_sorted = re.sort() + re_decoded = re_sorted._row_decode([("s", s.dtype)], fields) + + assert_series_equal( + pl.Series("s", rows, s.dtype), re_decoded.get_column("s") + ) + + +@given( + df=dataframes( + excluded_dtypes=[pl.Categorical], + max_cols=3, + max_size=5, + ) +) +def test_df_sort_parametric(df: pl.DataFrame) -> None: + for i in range(4**df.width): + descending = [((i // (4**j)) % 4) in [2, 3] for j in range(df.width)] + nulls_last = [((i // (4**j)) % 4) in [1, 3] for j in range(df.width)] + + fields = [ + (descending, nulls_last, False) + for (descending, nulls_last) in zip(descending, nulls_last) + ] + + def cmp( + lhs: tuple[Element, ...], + rhs: tuple[Element, ...], + descending: list[bool] = descending, + nulls_last: list[bool] = nulls_last, + ) -> OrderSign: + return tuple_order(lhs, rhs, descending=descending, nulls_last=nulls_last) + + rows = df.rows() + rows.sort(key=functools.cmp_to_key(cmp)) # type: ignore[arg-type, unused-ignore] + + re = df._row_encode(fields) + re_sorted = re.sort() + re_decoded = re_sorted._row_decode(df.schema.items(), fields) + + assert_frame_equal(pl.DataFrame(rows, df.schema, orient="row"), re_decoded) + + +def assert_order_series( + lhs: pl.series.series.ArrayLike, + rhs: pl.series.series.ArrayLike, + dtype: pl._typing.PolarsDataType, +) -> None: + lhs_df = pl.Series("lhs", lhs, dtype).to_frame() + rhs_df = pl.Series("rhs", rhs, dtype).to_frame() + + for descending in [False, True]: + for nulls_last in [False, True]: + field = (descending, nulls_last, False) + l_re = lhs_df._row_encode([field]).cast(pl.Binary) + r_re = rhs_df._row_encode([field]).cast(pl.Binary) + + order = [ + elem_order_sign( + lh[0], 
rh[0], descending=descending, nulls_last=nulls_last + ) + for (lh, rh) in zip(lhs_df.rows(), rhs_df.rows()) + ] + + assert_series_equal( + l_re < r_re, pl.Series([o == -1 for o in order]), check_names=False + ) + assert_series_equal( + l_re == r_re, pl.Series([o == 0 for o in order]), check_names=False + ) + assert_series_equal( + l_re > r_re, pl.Series([o == 1 for o in order]), check_names=False + ) + + +def parametric_order_base(df: pl.DataFrame) -> None: + lhs = df.get_columns()[0] + rhs = df.get_columns()[1] + + field = (False, False, False) + lhs_re = lhs.to_frame()._row_encode([field]).cast(pl.Binary) + rhs_re = rhs.to_frame()._row_encode([field]).cast(pl.Binary) + + assert_series_equal(lhs < rhs, lhs_re < rhs_re, check_names=False) + assert_series_equal(lhs == rhs, lhs_re == rhs_re, check_names=False) + assert_series_equal(lhs > rhs, lhs_re > rhs_re, check_names=False) + + field = (True, False, False) + lhs_re = lhs.to_frame()._row_encode([field]).cast(pl.Binary) + rhs_re = rhs.to_frame()._row_encode([field]).cast(pl.Binary) + + assert_series_equal(lhs > rhs, lhs_re < rhs_re, check_names=False) + assert_series_equal(lhs == rhs, lhs_re == rhs_re, check_names=False) + assert_series_equal(lhs < rhs, lhs_re > rhs_re, check_names=False) + + +@given( + df=dataframes([column(dtype=pl.Int32), column(dtype=pl.Int32)], allow_null=False) +) +def test_parametric_int_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +@given( + df=dataframes([column(dtype=pl.UInt32), column(dtype=pl.UInt32)], allow_null=False) +) +def test_parametric_uint_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +@given( + df=dataframes([column(dtype=pl.String), column(dtype=pl.String)], allow_null=False) +) +def test_parametric_string_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +@given( + df=dataframes([column(dtype=pl.Binary), column(dtype=pl.Binary)], allow_null=False) +) +def test_parametric_binary_order(df: pl.DataFrame) -> None: + 
parametric_order_base(df) + + +def test_order_bool() -> None: + dtype = pl.Boolean + assert_order_series([None, False, True], [True, False, None], dtype) + assert_order_series( + [None, False, True], + [True, False, None], + dtype, + ) + + assert_order_series( + [False, False, True, True], + [True, False, True, False], + dtype, + ) + assert_order_series( + [False, False, True, True], + [True, False, True, False], + dtype, + ) + + +def test_order_int() -> None: + dtype = pl.Int32 + assert_order_series([1, 2, 3], [3, 2, 1], dtype) + assert_order_series([-1, 0, 1], [1, 0, -1], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], [1], dtype) + + +def test_order_uint() -> None: + dtype = pl.UInt32 + assert_order_series([1, 2, 3], [3, 2, 1], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], [1], dtype) + + +def test_order_str() -> None: + dtype = pl.String + assert_order_series(["a", "b", "c"], ["c", "b", "a"], dtype) + assert_order_series(["a", "aa", "aaa"], ["aaa", "aa", "a"], dtype) + assert_order_series(["", "a", "aa"], ["aa", "a", ""], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], ["a"], dtype) + + +def test_order_bin() -> None: + dtype = pl.Binary + assert_order_series([b"a", b"b", b"c"], [b"c", b"b", b"a"], dtype) + assert_order_series([b"a", b"aa", b"aaa"], [b"aaa", b"aa", b"a"], dtype) + assert_order_series([b"", b"a", b"aa"], [b"aa", b"a", b""], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], [b"a"], dtype) + assert_order_series([None], [b"a"], dtype) + + +def test_order_list() -> None: + dtype = pl.List(pl.Int32) + assert_order_series([[1, 2, 3]], [[3, 2, 1]], dtype) + assert_order_series([[-1, 0, 1]], [[1, 0, -1]], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], [[1, 2, 3]], dtype) + assert_order_series([[None, 2, 3]], [[None, 2, 1]], dtype) + + assert_order_series([[]], [[None]], dtype) + 
assert_order_series([[]], [[1]], dtype) + assert_order_series([[1]], [[1, 2]], dtype) + + +def test_order_array() -> None: + dtype = pl.Array(pl.Int32, 3) + assert_order_series([[1, 2, 3]], [[3, 2, 1]], dtype) + assert_order_series([[-1, 0, 1]], [[1, 0, -1]], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], [[1, 2, 3]], dtype) + assert_order_series([[None, 2, 3]], [[None, 2, 1]], dtype) + + +def test_order_masked_array() -> None: + dtype = pl.Array(pl.Int32, 3) + lhs = pl.Series("l", [1, 2, 3], pl.Int32).replace(1, None).reshape((1, 3)) + rhs = pl.Series("r", [3, 2, 1], pl.Int32).replace(3, None).reshape((1, 3)) + assert_order_series(lhs, rhs, dtype) + + +def test_order_masked_struct() -> None: + dtype = pl.Array(pl.Int32, 3) + lhs = pl.Series("l", [1, 2, 3], pl.Int32).replace(1, None).reshape((1, 3)) + rhs = pl.Series("r", [3, 2, 1], pl.Int32).replace(3, None).reshape((1, 3)) + assert_order_series(lhs.to_frame().to_struct(), rhs.to_frame().to_struct(), dtype) + + +def test_order_enum() -> None: + dtype = pl.Enum(["a", "b", "c"]) + + assert_order_series(["a", "b", "c"], ["c", "b", "a"], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], ["a"], dtype) From c5a8efafe89e2f2b023dbf855b7e1e0510241268 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Tue, 3 Dec 2024 23:59:17 +1100 Subject: [PATCH 10/20] fix: Incorrect aggregation of empty groups after slice (#20127) --- crates/polars-core/src/frame/column/mod.rs | 3 ++- crates/polars-expr/src/expressions/slice.rs | 5 +++++ .../aggregation/test_aggregations.py | 18 ++++++++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 6e2a36be8b75..39f363f6463d 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -657,7 +657,8 @@ impl Column { let mut s = 
scalar_col.take_materialized_series().rechunk(); // SAFETY: We perform a compute_len afterwards. let chunks = unsafe { s.chunks_mut() }; - chunks[0].with_validity(Some(validity)); + let arr = &mut chunks[0]; + *arr = arr.with_validity(Some(validity)); s.compute_len(); s.into_column() diff --git a/crates/polars-expr/src/expressions/slice.rs b/crates/polars-expr/src/expressions/slice.rs index 0c2688d7999a..72bb6376466c 100644 --- a/crates/polars-expr/src/expressions/slice.rs +++ b/crates/polars-expr/src/expressions/slice.rs @@ -110,6 +110,11 @@ impl PhysicalExpr for SliceExpr { .collect::>>() })?; let mut ac = results.pop().unwrap(); + + if let AggState::AggregatedScalar(_) = ac.agg_state() { + polars_bail!(InvalidOperation: "cannot slice() an aggregated scalar value") + } + let mut ac_length = results.pop().unwrap(); let mut ac_offset = results.pop().unwrap(); diff --git a/py-polars/tests/unit/operations/aggregation/test_aggregations.py b/py-polars/tests/unit/operations/aggregation/test_aggregations.py index 1c8bcabba74c..5d4826da6d62 100644 --- a/py-polars/tests/unit/operations/aggregation/test_aggregations.py +++ b/py-polars/tests/unit/operations/aggregation/test_aggregations.py @@ -742,3 +742,21 @@ def test_sort_by_over_multiple_nulls_last() -> None: } ) assert_frame_equal(out, expected) + + +def test_slice_after_agg_raises() -> None: + with pytest.raises( + InvalidOperationError, match=r"cannot slice\(\) an aggregated scalar value" + ): + pl.select(a=1, b=1).group_by("a").agg(pl.col("b").first().slice(99, 0)) + + +def test_agg_scalar_empty_groups_20115() -> None: + assert_frame_equal( + ( + pl.DataFrame({"key": [123], "value": [456]}) + .group_by("key") + .agg(pl.col("value").slice(1, 1).first()) + ), + pl.select(key=pl.lit(123, pl.Int64), value=pl.lit(None, pl.Int64)), + ) From 394410bbd4108937256aae0eaafbbd37e7675421 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Tue, 3 Dec 2024 13:59:29 +0100 Subject: [PATCH 11/20] fix: Properly coerce types in lists 
(#20126) --- .../src/series/arithmetic/borrowed.rs | 18 ++++++++---------- py-polars/tests/unit/datatypes/test_struct.py | 7 ++++++- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 5fa31e7847d3..d52916364b0e 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -347,20 +347,18 @@ pub(crate) fn coerce_lhs_rhs<'a>( return Ok(result); } let (left_dtype, right_dtype) = (lhs.dtype(), rhs.dtype()); - let leaf_super_dtype = match (left_dtype, right_dtype) { - #[cfg(feature = "dtype-struct")] - (DataType::Struct(_), DataType::Struct(_)) => { - return Ok((Cow::Borrowed(lhs), Cow::Borrowed(rhs))) - }, - _ => try_get_supertype(left_dtype.leaf_dtype(), right_dtype.leaf_dtype())?, - }; + let leaf_super_dtype = try_get_supertype(left_dtype.leaf_dtype(), right_dtype.leaf_dtype())?; let mut new_left_dtype = left_dtype.cast_leaf(leaf_super_dtype.clone()); let mut new_right_dtype = right_dtype.cast_leaf(leaf_super_dtype); - // Cast List<->Array to List - if (left_dtype.is_list() && right_dtype.is_array()) - || (left_dtype.is_array() && right_dtype.is_list()) + // Correct the list and array types + // + // This also casts Lists <-> Array. 
+ if left_dtype.is_list() + || right_dtype.is_list() + || left_dtype.is_array() + || right_dtype.is_array() { new_left_dtype = try_get_supertype(&new_left_dtype, &new_right_dtype)?; new_right_dtype = new_left_dtype.clone(); diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index a2588430ecd2..7ead49773384 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -3,7 +3,7 @@ import io from dataclasses import dataclass from datetime import datetime, time -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable import pandas as pd import pyarrow as pa @@ -1214,3 +1214,8 @@ def test_struct_field_list_eval_17356() -> None: [{"name": "ALICE", "age": 65, "car": "Mazda"}], ], } + + +@pytest.mark.parametrize("data", [[1], [[1]], {"a": 1}, [{"a": 1}]]) +def test_leaf_list_eq_19613(data: Any) -> None: + assert not pl.DataFrame([data]).equals(pl.DataFrame([[data]])) From cf3b47feee56db33bb28124f6acff65aa2c92118 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 3 Dec 2024 17:20:31 +0400 Subject: [PATCH 12/20] feat(python): Add lazy support for `pl.select` (#20091) --- py-polars/polars/functions/lazy.py | 47 +++++++++++++++++-- .../unit/functions/range/test_int_range.py | 6 +++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 30185cc6a586..69305597294f 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -1874,11 +1874,30 @@ def collect_all_async( return result -def select(*exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr) -> DataFrame: +@overload +def select( + *exprs: IntoExpr | Iterable[IntoExpr], + eager: Literal[True] = ..., + **named_exprs: IntoExpr, +) -> DataFrame: ... 
+ + +@overload +def select( + *exprs: IntoExpr | Iterable[IntoExpr], + eager: Literal[False], + **named_exprs: IntoExpr, +) -> LazyFrame: ... + + +def select( + *exprs: IntoExpr | Iterable[IntoExpr], eager: bool = True, **named_exprs: IntoExpr +) -> DataFrame | LazyFrame: """ Run polars expressions without a context. - This is syntactic sugar for running `df.select` on an empty DataFrame. + This is syntactic sugar for running `df.select` on an empty DataFrame + (or LazyFrame if eager=False). Parameters ---------- @@ -1886,13 +1905,16 @@ def select(*exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr) -> Da Column(s) to select, specified as positional arguments. Accepts expression input. Strings are parsed as column names, other non-expression inputs are parsed as literals. + eager + Evaluate immediately and return a `DataFrame` (default); if set to `False`, + return a `LazyFrame` instead. **named_exprs Additional columns to select, specified as keyword arguments. The columns will be renamed to the keyword used. Returns ------- - DataFrame + DataFrame or LazyFrame Examples -------- @@ -1909,8 +1931,25 @@ def select(*exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr) -> Da │ 2 │ │ 1 │ └─────┘ + + >>> pl.select(pl.int_range(0, 100_000, 2).alias("n"), eager=False).filter( + ... pl.col("n") % 22_500 == 0 + ... 
).collect() + shape: (5, 1) + ┌───────┐ + │ n │ + │ --- │ + │ i64 │ + ╞═══════╡ + │ 0 │ + │ 22500 │ + │ 45000 │ + │ 67500 │ + │ 90000 │ + └───────┘ """ - return pl.DataFrame().select(*exprs, **named_exprs) + empty_frame = pl.DataFrame() if eager else pl.LazyFrame() + return empty_frame.select(*exprs, **named_exprs) @overload diff --git a/py-polars/tests/unit/functions/range/test_int_range.py b/py-polars/tests/unit/functions/range/test_int_range.py index 78c85c5e9a98..f1d47c3000a6 100644 --- a/py-polars/tests/unit/functions/range/test_int_range.py +++ b/py-polars/tests/unit/functions/range/test_int_range.py @@ -70,6 +70,12 @@ def test_int_range_eager() -> None: assert_series_equal(result, expected) +def test_int_range_lazy() -> None: + lf = pl.select(n=pl.int_range(8, 0, -2), eager=False) + expected = pl.LazyFrame({"n": [8, 6, 4, 2]}) + assert_frame_equal(lf, expected) + + def test_int_range_schema() -> None: result = pl.LazyFrame().select(int=pl.int_range(-3, 3)) From cc05ff2d319159bf3d87a6c157806563031009e8 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Tue, 3 Dec 2024 14:34:11 +0100 Subject: [PATCH 13/20] fix: Check shape for `*_horizontal` functions (#20130) --- crates/polars-core/src/frame/mod.rs | 181 +------------- .../src/series/arithmetic/horizontal.rs | 223 ++++++++++++++++++ .../polars-core/src/series/arithmetic/mod.rs | 1 + .../polars-ops/src/series/ops/horizontal.rs | 51 ++-- .../src/dsl/function_expr/dispatch.rs | 27 ++- .../polars-plan/src/dsl/function_expr/mod.rs | 28 ++- .../src/dsl/function_expr/schema.rs | 4 +- .../src/dsl/functions/horizontal.rs | 8 +- crates/polars-python/src/dataframe/general.rs | 39 --- .../src/functions/aggregation.rs | 8 +- crates/polars-python/src/lazyframe/visit.rs | 2 +- .../src/lazyframe/visitor/expr_nodes.rs | 8 +- py-polars/polars/dataframe/frame.py | 8 +- .../functions/aggregation/horizontal.py | 18 +- .../tests/unit/functions/test_horizontal.py | 21 ++ 15 files changed, 359 insertions(+), 268 deletions(-) create 
mode 100644 crates/polars-core/src/series/arithmetic/horizontal.rs create mode 100644 py-polars/tests/unit/functions/test_horizontal.py diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 2dc78741f977..09c9b18a03e7 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1,6 +1,4 @@ //! DataFrame module. -#[cfg(feature = "zip_with")] -use std::borrow::Cow; use std::sync::OnceLock; use std::{mem, ops}; @@ -13,6 +11,7 @@ use crate::chunked_array::metadata::MetadataFlags; #[cfg(feature = "algorithm_group_by")] use crate::chunked_array::ops::unique::is_unique_helper; use crate::prelude::*; +use crate::series::arithmetic::horizontal as series_horizontal; #[cfg(feature = "row_hash")] use crate::utils::split_df; use crate::utils::{slice_offsets, try_get_supertype, Container, NoNull}; @@ -38,11 +37,8 @@ use polars_utils::pl_str::PlSmallStr; use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; -use crate::chunked_array::cast::CastOptions; #[cfg(feature = "row_hash")] use crate::hashing::_df_rows_to_hashes_threaded_vertical; -#[cfg(feature = "zip_with")] -use crate::prelude::min_max_binary::min_max_binary_columns; use crate::prelude::sort::{argsort_multiple_row_fmt, prepare_arg_sort}; use crate::series::IsSorted; use crate::POOL; @@ -2798,186 +2794,23 @@ impl DataFrame { /// Aggregate the column horizontally to their min values. 
#[cfg(feature = "zip_with")] pub fn min_horizontal(&self) -> PolarsResult> { - let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true); - - match self.columns.len() { - 0 => Ok(None), - 1 => Ok(Some(self.columns[0].clone())), - 2 => min_fn(&self.columns[0], &self.columns[1]).map(Some), - _ => { - // the try_reduce_with is a bit slower in parallelism, - // but I don't think it matters here as we parallelize over columns, not over elements - POOL.install(|| { - self.columns - .par_iter() - .map(|s| Ok(Cow::Borrowed(s))) - .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned)) - // we can unwrap the option, because we are certain there is a column - // we started this operation on 3 columns - .unwrap() - .map(|cow| Some(cow.into_owned())) - }) - }, - } + series_horizontal::min_horizontal(&self.columns) } /// Aggregate the column horizontally to their max values. #[cfg(feature = "zip_with")] pub fn max_horizontal(&self) -> PolarsResult> { - let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false); - - match self.columns.len() { - 0 => Ok(None), - 1 => Ok(Some(self.columns[0].clone())), - 2 => max_fn(&self.columns[0], &self.columns[1]).map(Some), - _ => { - // the try_reduce_with is a bit slower in parallelism, - // but I don't think it matters here as we parallelize over columns, not over elements - POOL.install(|| { - self.columns - .par_iter() - .map(|s| Ok(Cow::Borrowed(s))) - .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned)) - // we can unwrap the option, because we are certain there is a column - // we started this operation on 3 columns - .unwrap() - .map(|cow| Some(cow.into_owned())) - }) - }, - } + series_horizontal::max_horizontal(&self.columns) } /// Sum all values horizontally across columns. 
- pub fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { - let apply_null_strategy = - |s: Series, null_strategy: NullStrategy| -> PolarsResult { - if let NullStrategy::Ignore = null_strategy { - // if has nulls - if s.null_count() > 0 { - return s.fill_null(FillNullStrategy::Zero); - } - } - Ok(s) - }; - - let sum_fn = - |acc: Series, s: Series, null_strategy: NullStrategy| -> PolarsResult { - let acc: Series = apply_null_strategy(acc, null_strategy)?; - let s = apply_null_strategy(s, null_strategy)?; - // This will do owned arithmetic and can be mutable - std::ops::Add::add(acc, s) - }; - - let non_null_cols = self - .materialized_column_iter() - .filter(|x| x.dtype() != &DataType::Null) - .collect::>(); - - match non_null_cols.len() { - 0 => { - if self.columns.is_empty() { - Ok(None) - } else { - // all columns are null dtype, so result is null dtype - Ok(Some(self.columns[0].as_materialized_series().clone())) - } - }, - 1 => Ok(Some(apply_null_strategy( - if non_null_cols[0].dtype() == &DataType::Boolean { - non_null_cols[0].cast(&DataType::UInt32)? - } else { - non_null_cols[0].clone() - }, - null_strategy, - )?)), - 2 => sum_fn( - non_null_cols[0].clone(), - non_null_cols[1].clone(), - null_strategy, - ) - .map(Some), - _ => { - // the try_reduce_with is a bit slower in parallelism, - // but I don't think it matters here as we parallelize over columns, not over elements - let out = POOL.install(|| { - non_null_cols - .into_par_iter() - .cloned() - .map(Ok) - .try_reduce_with(|l, r| sum_fn(l, r, null_strategy)) - // We can unwrap because we started with at least 3 columns, so we always get a Some - .unwrap() - }); - out.map(Some) - }, - } + pub fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { + series_horizontal::sum_horizontal(&self.columns, null_strategy) } /// Compute the mean of all numeric values horizontally across columns. 
- pub fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { - let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = - self.columns.iter().partition(|s| { - let dtype = s.dtype(); - dtype.is_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null() - }); - - if !non_numeric_columns.is_empty() { - let col = non_numeric_columns.first().cloned(); - polars_bail!( - InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})", - col.unwrap().name(), - col.unwrap().dtype(), - ); - } - let columns = numeric_columns.into_iter().cloned().collect::>(); - match columns.len() { - 0 => Ok(None), - 1 => Ok(Some(match columns[0].dtype() { - dt if dt != &DataType::Float32 && !dt.is_decimal() => columns[0] - .as_materialized_series() - .cast(&DataType::Float64)?, - _ => columns[0].as_materialized_series().clone(), - })), - _ => { - let numeric_df = unsafe { DataFrame::_new_no_checks_impl(self.height(), columns) }; - let sum = || numeric_df.sum_horizontal(null_strategy); - let null_count = || { - numeric_df - .par_materialized_column_iter() - .map(|s| { - s.is_null() - .cast_with_options(&DataType::UInt32, CastOptions::NonStrict) - }) - .reduce_with(|l, r| { - let l = l?; - let r = r?; - let result = std::ops::Add::add(&l, &r)?; - PolarsResult::Ok(result) - }) - // we can unwrap the option, because we are certain there is a column - // we started this operation on 2 columns - .unwrap() - }; - - let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count)); - let sum = sum?; - let null_count = null_count?; - - // value lengths: len - null_count - let value_length: UInt32Chunked = - (numeric_df.width().sub(&null_count)).u32().unwrap().clone(); - - // make sure that we do not divide by zero - // by replacing with None - let value_length = value_length - .set(&value_length.equal(0), None)? 
- .into_series() - .cast(&DataType::Float64)?; - - sum.map(|sum| std::ops::Div::div(&sum, &value_length)) - .transpose() - }, - } + pub fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { + series_horizontal::mean_horizontal(&self.columns, null_strategy) } /// Pipe different functions/ closure operations that work on a DataFrame together. diff --git a/crates/polars-core/src/series/arithmetic/horizontal.rs b/crates/polars-core/src/series/arithmetic/horizontal.rs new file mode 100644 index 000000000000..94996b7302fe --- /dev/null +++ b/crates/polars-core/src/series/arithmetic/horizontal.rs @@ -0,0 +1,223 @@ +use std::borrow::Cow; + +use polars_error::{polars_bail, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; +use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; + +#[cfg(feature = "zip_with")] +use super::min_max_binary::min_max_binary_columns; +use super::{ + ChunkCompareEq, ChunkSet, Column, DataType, FillNullStrategy, IntoColumn, Scalar, Series, + UInt32Chunked, +}; +use crate::chunked_array::cast::CastOptions; +use crate::frame::NullStrategy; +use crate::POOL; + +/// Aggregate the column horizontally to their min values. +/// +/// All columns need to be the same length or a scalar. 
+#[cfg(feature = "zip_with")] +pub fn min_horizontal(columns: &[Column]) -> PolarsResult> { + let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true); + + match columns.len() { + 0 => Ok(None), + 1 => Ok(Some(columns[0].clone())), + 2 => min_fn(&columns[0], &columns[1]).map(Some), + _ => { + // the try_reduce_with is a bit slower in parallelism, + // but I don't think it matters here as we parallelize over columns, not over elements + POOL.install(|| { + columns + .par_iter() + .map(|s| Ok(Cow::Borrowed(s))) + .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned)) + // we can unwrap the option, because we are certain there is a column + // we started this operation on 3 columns + .unwrap() + .map(|cow| Some(cow.into_owned())) + }) + }, + } +} + +/// Aggregate the column horizontally to their max values. +/// +/// All columns need to be the same length or a scalar. +#[cfg(feature = "zip_with")] +pub fn max_horizontal(columns: &[Column]) -> PolarsResult> { + let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false); + + match columns.len() { + 0 => Ok(None), + 1 => Ok(Some(columns[0].clone())), + 2 => max_fn(&columns[0], &columns[1]).map(Some), + _ => { + // the try_reduce_with is a bit slower in parallelism, + // but I don't think it matters here as we parallelize over columns, not over elements + POOL.install(|| { + columns + .par_iter() + .map(|s| Ok(Cow::Borrowed(s))) + .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned)) + // we can unwrap the option, because we are certain there is a column + // we started this operation on 3 columns + .unwrap() + .map(|cow| Some(cow.into_owned())) + }) + }, + } +} + +/// Sum all values horizontally across columns. +/// +/// All columns need to be the same length or a scalar. 
+pub fn sum_horizontal( + columns: &[Column], + null_strategy: NullStrategy, +) -> PolarsResult> { + let apply_null_strategy = |s: Series, null_strategy: NullStrategy| -> PolarsResult { + if let NullStrategy::Ignore = null_strategy { + // if has nulls + if s.null_count() > 0 { + return s.fill_null(FillNullStrategy::Zero); + } + } + Ok(s) + }; + + let sum_fn = |acc: Series, s: Series, null_strategy: NullStrategy| -> PolarsResult { + let acc: Series = apply_null_strategy(acc, null_strategy)?; + let s = apply_null_strategy(s, null_strategy)?; + // This will do owned arithmetic and can be mutable + std::ops::Add::add(acc, s) + }; + + // @scalar-opt + let non_null_cols = columns + .iter() + .filter(|x| x.dtype() != &DataType::Null) + .map(|c| c.as_materialized_series()) + .collect::>(); + + match non_null_cols.len() { + 0 => { + if columns.is_empty() { + Ok(None) + } else { + // all columns are null dtype, so result is null dtype + Ok(Some(columns[0].clone())) + } + }, + 1 => Ok(Some( + apply_null_strategy( + if non_null_cols[0].dtype() == &DataType::Boolean { + non_null_cols[0].cast(&DataType::UInt32)? + } else { + non_null_cols[0].clone() + }, + null_strategy, + )? + .into(), + )), + 2 => sum_fn( + non_null_cols[0].clone(), + non_null_cols[1].clone(), + null_strategy, + ) + .map(Column::from) + .map(Some), + _ => { + // the try_reduce_with is a bit slower in parallelism, + // but I don't think it matters here as we parallelize over columns, not over elements + let out = POOL.install(|| { + non_null_cols + .into_par_iter() + .cloned() + .map(Ok) + .try_reduce_with(|l, r| sum_fn(l, r, null_strategy)) + // We can unwrap because we started with at least 3 columns, so we always get a Some + .unwrap() + }); + out.map(Column::from).map(Some) + }, + } +} + +/// Compute the mean of all values horizontally across columns. +/// +/// All columns need to be the same length or a scalar. 
+pub fn mean_horizontal( + columns: &[Column], + null_strategy: NullStrategy, +) -> PolarsResult> { + let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| { + let dtype = s.dtype(); + dtype.is_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null() + }); + + if !non_numeric_columns.is_empty() { + let col = non_numeric_columns.first().cloned(); + polars_bail!( + InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})", + col.unwrap().name(), + col.unwrap().dtype(), + ); + } + let columns = numeric_columns.into_iter().cloned().collect::>(); + match columns.len() { + 0 => Ok(None), + 1 => Ok(Some(match columns[0].dtype() { + dt if dt != &DataType::Float32 && !dt.is_decimal() => { + columns[0].cast(&DataType::Float64)? + }, + _ => columns[0].clone(), + })), + _ => { + let sum = || sum_horizontal(&columns, null_strategy); + let null_count = || { + columns + .par_iter() + .map(|c| { + c.is_null() + .into_column() + .cast_with_options(&DataType::UInt32, CastOptions::NonStrict) + }) + .reduce_with(|l, r| { + let l = l?; + let r = r?; + let result = std::ops::Add::add(&l, &r)?; + PolarsResult::Ok(result) + }) + // we can unwrap the option, because we are certain there is a column + // we started this operation on 2 columns + .unwrap() + }; + + let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count)); + let sum = sum?; + let null_count = null_count?; + + // value lengths: len - null_count + let value_length: UInt32Chunked = (Column::new_scalar( + PlSmallStr::EMPTY, + Scalar::from(columns.len() as u32), + null_count.len(), + ) - null_count)? + .u32() + .unwrap() + .clone(); + + // make sure that we do not divide by zero + // by replacing with None + let value_length = value_length + .set(&value_length.equal(0), None)? 
+ .into_column() + .cast(&DataType::Float64)?; + + sum.map(|sum| std::ops::Div::div(&sum, &value_length)) + .transpose() + }, + } +} diff --git a/crates/polars-core/src/series/arithmetic/mod.rs b/crates/polars-core/src/series/arithmetic/mod.rs index 8a4d317276c9..7cb1fd4674f2 100644 --- a/crates/polars-core/src/series/arithmetic/mod.rs +++ b/crates/polars-core/src/series/arithmetic/mod.rs @@ -1,5 +1,6 @@ mod bitops; mod borrowed; +pub mod horizontal; mod list; mod owned; diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index 663ac3664c8e..7e5cec8c1474 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -1,36 +1,43 @@ use polars_core::frame::NullStrategy; use polars_core::prelude::*; +fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> { + let mut length = 1; + for c in cs { + let len = c.len(); + if len != 1 && len != length { + if length == 1 { + length = len; + } else { + polars_bail!(ShapeMismatch: "cannot evaluate two Series of different lengths ({len} and {length})"); + } + } + } + Ok(()) +} + pub fn max_horizontal(s: &[Column]) -> PolarsResult> { - let df = - unsafe { DataFrame::_new_no_checks_impl(s.first().map_or(0, Column::len), Vec::from(s)) }; - df.max_horizontal() - .map(|s| s.map(Column::from)) - .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) + validate_column_lengths(s)?; + polars_core::series::arithmetic::horizontal::max_horizontal(s) + .map(|opt_c| opt_c.map(|res| res.with_name(s[0].name().clone()))) } pub fn min_horizontal(s: &[Column]) -> PolarsResult> { - let df = - unsafe { DataFrame::_new_no_checks_impl(s.first().map_or(0, Column::len), Vec::from(s)) }; - df.min_horizontal() - .map(|s| s.map(Column::from)) - .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) + validate_column_lengths(s)?; + polars_core::series::arithmetic::horizontal::min_horizontal(s) + .map(|opt_c| 
opt_c.map(|res| res.with_name(s[0].name().clone()))) } -pub fn sum_horizontal(s: &[Column]) -> PolarsResult> { - let df = - unsafe { DataFrame::_new_no_checks_impl(s.first().map_or(0, Column::len), Vec::from(s)) }; - df.sum_horizontal(NullStrategy::Ignore) - .map(|s| s.map(Column::from)) - .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) +pub fn sum_horizontal(s: &[Column], null_strategy: NullStrategy) -> PolarsResult> { + validate_column_lengths(s)?; + polars_core::series::arithmetic::horizontal::sum_horizontal(s, null_strategy) + .map(|opt_c| opt_c.map(|res| res.with_name(s[0].name().clone()))) } -pub fn mean_horizontal(s: &[Column]) -> PolarsResult> { - let df = - unsafe { DataFrame::_new_no_checks_impl(s.first().map_or(0, Column::len), Vec::from(s)) }; - df.mean_horizontal(NullStrategy::Ignore) - .map(|s| s.map(Column::from)) - .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) +pub fn mean_horizontal(s: &[Column], null_strategy: NullStrategy) -> PolarsResult> { + validate_column_lengths(s)?; + polars_core::series::arithmetic::horizontal::mean_horizontal(s, null_strategy) + .map(|opt_c| opt_c.map(|res| res.with_name(s[0].name().clone()))) } pub fn coalesce_columns(s: &[Column]) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index 895803cbcc63..7a338553e1e1 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -1,3 +1,5 @@ +use polars_core::frame::NullStrategy; + use super::*; pub(super) fn reverse(s: &Column) -> PolarsResult { @@ -103,12 +105,25 @@ pub(super) fn min_horizontal(s: &mut [Column]) -> PolarsResult> { polars_ops::prelude::min_horizontal(s) } -pub(super) fn sum_horizontal(s: &mut [Column]) -> PolarsResult> { - polars_ops::prelude::sum_horizontal(s) -} - -pub(super) fn mean_horizontal(s: &mut [Column]) -> PolarsResult> { - 
polars_ops::prelude::mean_horizontal(s) +pub(super) fn sum_horizontal(s: &mut [Column], ignore_nulls: bool) -> PolarsResult> { + let null_strategy = if ignore_nulls { + NullStrategy::Ignore + } else { + NullStrategy::Propagate + }; + polars_ops::prelude::sum_horizontal(s, null_strategy) +} + +pub(super) fn mean_horizontal( + s: &mut [Column], + ignore_nulls: bool, +) -> PolarsResult> { + let null_strategy = if ignore_nulls { + NullStrategy::Ignore + } else { + NullStrategy::Propagate + }; + polars_ops::prelude::mean_horizontal(s, null_strategy) } pub(super) fn drop_nulls(s: &Column) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 5813fa7a72cd..d694fcd0409e 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -333,8 +333,12 @@ pub enum FunctionExpr { }, MaxHorizontal, MinHorizontal, - SumHorizontal, - MeanHorizontal, + SumHorizontal { + ignore_nulls: bool, + }, + MeanHorizontal { + ignore_nulls: bool, + }, #[cfg(feature = "ewma")] EwmMean { options: EWMOptions, @@ -420,8 +424,16 @@ impl Hash for FunctionExpr { lib.hash(state); symbol.hash(state); }, - MaxHorizontal | MinHorizontal | SumHorizontal | MeanHorizontal | DropNans - | DropNulls | Reverse | ArgUnique | Shift | ShiftAndFill => {}, + MaxHorizontal + | MinHorizontal + | SumHorizontal { .. } + | MeanHorizontal { .. } + | DropNans + | DropNulls + | Reverse + | ArgUnique + | Shift + | ShiftAndFill => {}, #[cfg(feature = "mode")] Mode => {}, #[cfg(feature = "abs")] @@ -760,8 +772,8 @@ impl Display for FunctionExpr { ForwardFill { .. } => "forward_fill", MaxHorizontal => "max_horizontal", MinHorizontal => "min_horizontal", - SumHorizontal => "sum_horizontal", - MeanHorizontal => "mean_horizontal", + SumHorizontal { .. } => "sum_horizontal", + MeanHorizontal { .. } => "mean_horizontal", #[cfg(feature = "ewma")] EwmMean { .. 
} => "ewm_mean", #[cfg(feature = "ewma_by")] @@ -1170,8 +1182,8 @@ impl From for SpecialEq> { ForwardFill { limit } => map!(dispatch::forward_fill, limit), MaxHorizontal => wrap!(dispatch::max_horizontal), MinHorizontal => wrap!(dispatch::min_horizontal), - SumHorizontal => wrap!(dispatch::sum_horizontal), - MeanHorizontal => wrap!(dispatch::mean_horizontal), + SumHorizontal { ignore_nulls } => wrap!(dispatch::sum_horizontal, ignore_nulls), + MeanHorizontal { ignore_nulls } => wrap!(dispatch::mean_horizontal, ignore_nulls), #[cfg(feature = "ewma")] EwmMean { options } => map!(ewm::ewm_mean, options), #[cfg(feature = "ewma_by")] diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 8ac54c172993..018a3b0207d1 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -329,14 +329,14 @@ impl FunctionExpr { ForwardFill { .. } => mapper.with_same_dtype(), MaxHorizontal => mapper.map_to_supertype(), MinHorizontal => mapper.map_to_supertype(), - SumHorizontal => { + SumHorizontal { .. } => { if mapper.fields[0].dtype() == &DataType::Boolean { mapper.with_dtype(DataType::UInt32) } else { mapper.map_to_supertype() } }, - MeanHorizontal => mapper.map_to_float_dtype(), + MeanHorizontal { .. } => mapper.map_to_float_dtype(), #[cfg(feature = "ewma")] EwmMean { .. } => mapper.map_to_float_dtype(), #[cfg(feature = "ewma_by")] diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index 26b6209a720e..f81571c6ff32 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -274,13 +274,13 @@ pub fn min_horizontal>(exprs: E) -> PolarsResult { } /// Sum all values horizontally across columns. 
-pub fn sum_horizontal>(exprs: E) -> PolarsResult { +pub fn sum_horizontal>(exprs: E, ignore_nulls: bool) -> PolarsResult { let exprs = exprs.as_ref().to_vec(); polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); Ok(Expr::Function { input: exprs, - function: FunctionExpr::SumHorizontal, + function: FunctionExpr::SumHorizontal { ignore_nulls }, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, flags: FunctionFlags::default() @@ -292,13 +292,13 @@ pub fn sum_horizontal>(exprs: E) -> PolarsResult { } /// Compute the mean of all values horizontally across columns. -pub fn mean_horizontal>(exprs: E) -> PolarsResult { +pub fn mean_horizontal>(exprs: E, ignore_nulls: bool) -> PolarsResult { let exprs = exprs.as_ref().to_vec(); polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); Ok(Expr::Function { input: exprs, - function: FunctionExpr::MeanHorizontal, + function: FunctionExpr::MeanHorizontal { ignore_nulls }, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, flags: FunctionFlags::default() diff --git a/crates/polars-python/src/dataframe/general.rs b/crates/polars-python/src/dataframe/general.rs index 1c2f94368275..80741dea3700 100644 --- a/crates/polars-python/src/dataframe/general.rs +++ b/crates/polars-python/src/dataframe/general.rs @@ -3,7 +3,6 @@ use std::mem::ManuallyDrop; use either::Either; use polars::export::arrow::bitmap::MutableBitmap; use polars::prelude::*; -use polars_core::frame::*; #[cfg(feature = "pivot")] use polars_lazy::frame::pivot::{pivot, pivot_stable}; use polars_row::RowEncodingOptions; @@ -514,44 +513,6 @@ impl PyDataFrame { self.df.clone().lazy().into() } - pub fn max_horizontal(&self, py: Python) -> PyResult> { - let s = py - .allow_threads(|| self.df.max_horizontal()) - .map_err(PyPolarsErr::from)?; - Ok(s.map(|s| s.take_materialized_series().into())) - } - - pub 
fn min_horizontal(&self, py: Python) -> PyResult> { - let s = py - .allow_threads(|| self.df.min_horizontal()) - .map_err(PyPolarsErr::from)?; - Ok(s.map(|s| s.take_materialized_series().into())) - } - - pub fn sum_horizontal(&self, py: Python, ignore_nulls: bool) -> PyResult> { - let null_strategy = if ignore_nulls { - NullStrategy::Ignore - } else { - NullStrategy::Propagate - }; - let s = py - .allow_threads(|| self.df.sum_horizontal(null_strategy)) - .map_err(PyPolarsErr::from)?; - Ok(s.map(|s| s.into())) - } - - pub fn mean_horizontal(&self, py: Python, ignore_nulls: bool) -> PyResult> { - let null_strategy = if ignore_nulls { - NullStrategy::Ignore - } else { - NullStrategy::Propagate - }; - let s = py - .allow_threads(|| self.df.mean_horizontal(null_strategy)) - .map_err(PyPolarsErr::from)?; - Ok(s.map(|s| s.into())) - } - #[pyo3(signature = (columns, separator, drop_first=false))] pub fn to_dummies( &self, diff --git a/crates/polars-python/src/functions/aggregation.rs b/crates/polars-python/src/functions/aggregation.rs index 1d27ae8fee69..03c7f802f7ee 100644 --- a/crates/polars-python/src/functions/aggregation.rs +++ b/crates/polars-python/src/functions/aggregation.rs @@ -34,15 +34,15 @@ pub fn min_horizontal(exprs: Vec) -> PyResult { } #[pyfunction] -pub fn sum_horizontal(exprs: Vec) -> PyResult { +pub fn sum_horizontal(exprs: Vec, ignore_nulls: bool) -> PyResult { let exprs = exprs.to_exprs(); - let e = dsl::sum_horizontal(exprs).map_err(PyPolarsErr::from)?; + let e = dsl::sum_horizontal(exprs, ignore_nulls).map_err(PyPolarsErr::from)?; Ok(e.into()) } #[pyfunction] -pub fn mean_horizontal(exprs: Vec) -> PyResult { +pub fn mean_horizontal(exprs: Vec, ignore_nulls: bool) -> PyResult { let exprs = exprs.to_exprs(); - let e = dsl::mean_horizontal(exprs).map_err(PyPolarsErr::from)?; + let e = dsl::mean_horizontal(exprs, ignore_nulls).map_err(PyPolarsErr::from)?; Ok(e.into()) } diff --git a/crates/polars-python/src/lazyframe/visit.rs 
b/crates/polars-python/src/lazyframe/visit.rs index b698f68a47c7..27633e401301 100644 --- a/crates/polars-python/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -57,7 +57,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node). - const VERSION: Version = (3, 2); + const VERSION: Version = (4, 2); pub fn new(root: Node, lp_arena: Arena, expr_arena: Arena) -> Self { Self { diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index c5cda028f74b..604398d78857 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -1326,9 +1326,13 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { }, FunctionExpr::BackwardFill { limit } => ("backward_fill", limit).to_object(py), FunctionExpr::ForwardFill { limit } => ("forward_fill", limit).to_object(py), - FunctionExpr::SumHorizontal => ("sum_horizontal",).to_object(py), + FunctionExpr::SumHorizontal { ignore_nulls } => { + ("sum_horizontal", ignore_nulls).to_object(py) + }, FunctionExpr::MaxHorizontal => ("max_horizontal",).to_object(py), - FunctionExpr::MeanHorizontal => ("mean_horizontal",).to_object(py), + FunctionExpr::MeanHorizontal { ignore_nulls } => { + ("mean_horizontal", ignore_nulls).to_object(py) + }, FunctionExpr::MinHorizontal => ("min_horizontal",).to_object(py), FunctionExpr::EwmMean { options: _ } => { return Err(PyNotImplementedError::new_err("ewm mean")) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 92b10e81ca45..49e1053e375d 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -9528,7 +9528,9 @@ def sum_horizontal(self, *, ignore_nulls: bool = True) -> Series: 9.0 ] """ - return 
wrap_s(self._df.sum_horizontal(ignore_nulls)).alias("sum") + return self.select( + sum=F.sum_horizontal(F.all(), ignore_nulls=ignore_nulls) + ).to_series() def mean(self) -> DataFrame: """ @@ -9588,7 +9590,9 @@ def mean_horizontal(self, *, ignore_nulls: bool = True) -> Series: 4.5 ] """ - return wrap_s(self._df.mean_horizontal(ignore_nulls)).alias("mean") + return self.select( + mean=F.mean_horizontal(F.all(), ignore_nulls=ignore_nulls) + ).to_series() def std(self, ddof: int = 1) -> DataFrame: """ diff --git a/py-polars/polars/functions/aggregation/horizontal.py b/py-polars/polars/functions/aggregation/horizontal.py index 5406a77d287d..ba72d42b609a 100644 --- a/py-polars/polars/functions/aggregation/horizontal.py +++ b/py-polars/polars/functions/aggregation/horizontal.py @@ -178,7 +178,9 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return wrap_expr(plr.min_horizontal(pyexprs)) -def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: +def sum_horizontal( + *exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool = True +) -> Expr: """ Sum all values horizontally across columns. @@ -187,6 +189,9 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: *exprs Column(s) to use in the aggregation. Accepts expression input. Strings are parsed as column names, other non-expression inputs are parsed as literals. + ignore_nulls + Ignore null values (default). + If set to `False`, any null value in the input will lead to a null output. 
Examples -------- @@ -210,10 +215,12 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: └─────┴──────┴─────┴─────┘ """ pyexprs = parse_into_list_of_expressions(*exprs) - return wrap_expr(plr.sum_horizontal(pyexprs)) + return wrap_expr(plr.sum_horizontal(pyexprs, ignore_nulls)) -def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: +def mean_horizontal( + *exprs: IntoExpr | Iterable[IntoExpr], ignore_nulls: bool = True +) -> Expr: """ Compute the mean of all values horizontally across columns. @@ -222,6 +229,9 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: *exprs Column(s) to use in the aggregation. Accepts expression input. Strings are parsed as column names, other non-expression inputs are parsed as literals. + ignore_nulls + Ignore null values (default). + If set to `False`, any null value in the input will lead to a null output. Examples -------- @@ -245,7 +255,7 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: └─────┴──────┴─────┴──────┘ """ pyexprs = parse_into_list_of_expressions(*exprs) - return wrap_expr(plr.mean_horizontal(pyexprs)) + return wrap_expr(plr.mean_horizontal(pyexprs, ignore_nulls)) def cum_sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: diff --git a/py-polars/tests/unit/functions/test_horizontal.py b/py-polars/tests/unit/functions/test_horizontal.py new file mode 100644 index 000000000000..688e2bfcf58a --- /dev/null +++ b/py-polars/tests/unit/functions/test_horizontal.py @@ -0,0 +1,21 @@ +import pytest + +import polars as pl + + +@pytest.mark.parametrize( + "f", + [ + "min", + "max", + "sum", + "mean", + ], +) +def test_shape_mismatch_19336(f: str) -> None: + a = pl.Series([1, 2, 3]) + b = pl.Series([1, 2]) + fn = getattr(pl, f"{f}_horizontal") + + with pytest.raises(pl.exceptions.ShapeError): + pl.select((fn)(a, b)) From bf8f64c70db9a371b53f938e5d75a28946e65c92 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Tue, 3 Dec 2024 15:07:01 +0100 
Subject: [PATCH 14/20] fix: Handle slice pushdown in PythonUDF GroupBy (#20132) --- crates/polars-core/src/frame/group_by/mod.rs | 11 ++++++++ .../src/executors/group_by.rs | 2 +- .../tests/unit/operations/test_group_by.py | 28 +++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 7d1d9e761504..2a5aa2c40e45 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -841,6 +841,17 @@ impl<'df> GroupBy<'df> { df.as_single_chunk_par(); Ok(df) } + + pub fn sliced(mut self, slice: Option<(i64, usize)>) -> Self { + match slice { + None => self, + Some((offset, length)) => { + self.groups = (*self.groups.slice(offset, length)).clone(); + self.selected_keys = self.keys_sliced(slice); + self + }, + } + } } unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame { diff --git a/crates/polars-mem-engine/src/executors/group_by.rs b/crates/polars-mem-engine/src/executors/group_by.rs index 437b7fb574aa..f7a501424ed9 100644 --- a/crates/polars-mem-engine/src/executors/group_by.rs +++ b/crates/polars-mem-engine/src/executors/group_by.rs @@ -67,7 +67,7 @@ pub(super) fn group_by_helper( let gb = df.group_by_with_series(keys, true, maintain_order)?; if let Some(f) = apply { - return gb.apply(move |df| f.call_udf(df)); + return gb.sliced(slice).apply(move |df| f.call_udf(df)); } let mut groups = gb.get_groups(); diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index cff43b43274c..645b978214f0 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -1153,3 +1153,31 @@ def test_group_by_agg_19173() -> None: out = df.head(0).group_by("g").agg((pl.col.x - pl.col.x.sum() * pl.col.x) ** 2) assert out.to_dict(as_series=False) == {"g": [], "x": []} assert out.schema == 
pl.Schema([("g", pl.Int64), ("x", pl.List(pl.Float64))]) + + +def test_group_by_map_groups_slice_pushdown_20002() -> None: + schema = { + "a": pl.Int8, + "b": pl.UInt8, + } + + df = ( + pl.LazyFrame( + data={"a": [1, 2, 3, 4, 5], "b": [90, 80, 70, 60, 50]}, + schema=schema, + ) + .group_by("a", maintain_order=True) + .map_groups(lambda df: df * 2.0, schema=schema) + .head(3) + .collect() + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "a": [2.0, 4.0, 6.0], + "b": [180.0, 160.0, 140.0], + } + ), + ) From 9b91418a1d03a8312f94771ee294835928271bb9 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Tue, 3 Dec 2024 16:47:52 +0100 Subject: [PATCH 15/20] fix: Implement `arg_sort` for Null series (#20135) --- crates/polars-core/src/series/implementations/null.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index 72b4e14114eb..6e88f2e26c3c 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -263,6 +263,10 @@ impl SeriesTrait for NullChunked { Ok(self.clone().into_series()) } + fn arg_sort(&self, _options: SortOptions) -> IdxCa { + IdxCa::from_vec(self.name().clone(), (0..self.len() as IdxSize).collect()) + } + fn is_null(&self) -> BooleanChunked { BooleanChunked::full(self.name().clone(), true, self.len()) } From d629cae8df6d43f6d8781c7414b5c8bcfe492010 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Tue, 3 Dec 2024 16:54:13 +0100 Subject: [PATCH 16/20] chore: Remove useless SeriesTrait::get implementations (#20136) --- .../polars-core/src/series/implementations/array.rs | 4 ---- .../polars-core/src/series/implementations/binary.rs | 4 ---- .../src/series/implementations/binary_offset.rs | 4 ---- .../polars-core/src/series/implementations/boolean.rs | 4 ---- .../src/series/implementations/categorical.rs | 4 ---- crates/polars-core/src/series/implementations/date.rs | 4 
---- .../src/series/implementations/datetime.rs | 4 ---- .../polars-core/src/series/implementations/decimal.rs | 4 ---- .../src/series/implementations/duration.rs | 4 ---- .../polars-core/src/series/implementations/floats.rs | 4 ---- crates/polars-core/src/series/implementations/list.rs | 4 ---- crates/polars-core/src/series/implementations/mod.rs | 4 ---- crates/polars-core/src/series/implementations/null.rs | 5 ----- .../polars-core/src/series/implementations/string.rs | 4 ---- .../polars-core/src/series/implementations/struct_.rs | 4 ---- crates/polars-core/src/series/implementations/time.rs | 4 ---- crates/polars-core/src/series/series_trait.rs | 11 +++++++---- 17 files changed, 7 insertions(+), 69 deletions(-) diff --git a/crates/polars-core/src/series/implementations/array.rs b/crates/polars-core/src/series/implementations/array.rs index 51bd084cd46d..b1929167b515 100644 --- a/crates/polars-core/src/series/implementations/array.rs +++ b/crates/polars-core/src/series/implementations/array.rs @@ -145,10 +145,6 @@ impl SeriesTrait for SeriesWrap { self.0.cast_with_options(dtype, options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index b0fe79845a0d..fac7a5086c5d 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -174,10 +174,6 @@ impl SeriesTrait for SeriesWrap { self.0.cast_with_options(dtype, options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/binary_offset.rs 
b/crates/polars-core/src/series/implementations/binary_offset.rs index 481b5c5bf47e..07844d96a994 100644 --- a/crates/polars-core/src/series/implementations/binary_offset.rs +++ b/crates/polars-core/src/series/implementations/binary_offset.rs @@ -146,10 +146,6 @@ impl SeriesTrait for SeriesWrap { self.0.cast_with_options(dtype, options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index b4cd48295c4c..2c8becc2edab 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -208,10 +208,6 @@ impl SeriesTrait for SeriesWrap { self.0.cast_with_options(dtype, options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index 7c45baa5c054..00a32d0dea83 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -223,10 +223,6 @@ impl SeriesTrait for SeriesWrap { self.0.cast_with_options(dtype, options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/date.rs b/crates/polars-core/src/series/implementations/date.rs index a2ef6ed0788c..8aa8c43f608a 100644 --- a/crates/polars-core/src/series/implementations/date.rs +++ b/crates/polars-core/src/series/implementations/date.rs 
@@ -266,10 +266,6 @@ impl SeriesTrait for SeriesWrap { } } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index 7f0d575bd916..109275dbcd34 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -260,10 +260,6 @@ impl SeriesTrait for SeriesWrap { } } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 612505057eca..bd4049efb42f 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -294,10 +294,6 @@ impl SeriesTrait for SeriesWrap { self.0.cast_with_options(dtype, cast_options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 803ca813aa1c..555e630ac3fb 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -380,10 +380,6 @@ impl SeriesTrait for SeriesWrap { self.0.cast_with_options(dtype, cast_options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git 
a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 85c0d87cf0f1..780f8130ed72 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -254,10 +254,6 @@ macro_rules! impl_dyn_series { self.0.cast_with_options(dtype, cast_options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index 5e5b4a95d5e2..529a5f9e98d5 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -148,10 +148,6 @@ impl SeriesTrait for SeriesWrap { self.0.cast_with_options(dtype, cast_options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index d4b9626d2bfc..fb1a093f130d 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -326,10 +326,6 @@ macro_rules! 
impl_dyn_series { self.0.cast_with_options(dtype, options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index 6e88f2e26c3c..9fec870ee7e2 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -222,11 +222,6 @@ impl SeriesTrait for NullChunked { NullChunked::new(self.name.clone(), length).into_series() } - fn get(&self, index: usize) -> PolarsResult { - polars_ensure!(index < self.len(), oob = index, self.len()); - Ok(AnyValue::Null) - } - unsafe fn get_unchecked(&self, _index: usize) -> AnyValue { AnyValue::Null } diff --git a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs index 2547d9bd237c..1d68298681a5 100644 --- a/crates/polars-core/src/series/implementations/string.rs +++ b/crates/polars-core/src/series/implementations/string.rs @@ -179,10 +179,6 @@ impl SeriesTrait for SeriesWrap { self.0.cast_with_options(dtype, cast_options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index d40e53d1a01e..d9e75399a988 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -150,10 +150,6 @@ impl SeriesTrait for SeriesWrap { self.0.cast_with_options(dtype, cast_options) } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - unsafe fn get_unchecked(&self, index: usize) -> AnyValue { 
self.0.get_any_value_unchecked(index) } diff --git a/crates/polars-core/src/series/implementations/time.rs b/crates/polars-core/src/series/implementations/time.rs index 870efc27de7e..99696c8e84d6 100644 --- a/crates/polars-core/src/series/implementations/time.rs +++ b/crates/polars-core/src/series/implementations/time.rs @@ -228,10 +228,6 @@ impl SeriesTrait for SeriesWrap { } } - fn get(&self, index: usize) -> PolarsResult { - self.0.get_any_value(index) - } - #[inline] unsafe fn get_unchecked(&self, index: usize) -> AnyValue { self.0.get_any_value_unchecked(index) diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index c77a9de0f7ad..ed420ec0d0b1 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -398,7 +398,12 @@ pub trait SeriesTrait: /// Get a single value by index. Don't use this operation for loops as a runtime cast is /// needed for every iteration. - fn get(&self, _index: usize) -> PolarsResult; + fn get(&self, index: usize) -> PolarsResult { + polars_ensure!(index < self.len(), oob = index, self.len()); + // SAFETY: Just did bounds check + let value = unsafe { self.get_unchecked(index) }; + Ok(value) + } /// Get a single value by index. Don't use this operation for loops as a runtime cast is /// needed for every iteration. 
@@ -407,9 +412,7 @@ pub trait SeriesTrait: /// /// # Safety /// Does not do any bounds checking - unsafe fn get_unchecked(&self, _index: usize) -> AnyValue { - invalid_operation_panic!(get_unchecked, self) - } + unsafe fn get_unchecked(&self, _index: usize) -> AnyValue; fn sort_with(&self, _options: SortOptions) -> PolarsResult { polars_bail!(opq = sort_with, self._dtype()); From 3faddcb2506e830d59e6be5834a596b617217ef3 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Tue, 3 Dec 2024 16:54:55 +0100 Subject: [PATCH 17/20] chore: Move horizontal methods to polars-ops (#20134) --- .../src/chunked_array/ops/min_max_binary.rs | 81 ---- .../polars-core/src/chunked_array/ops/mod.rs | 2 - crates/polars-core/src/frame/mod.rs | 68 ---- .../src/series/arithmetic/borrowed.rs | 2 +- .../src/series/arithmetic/horizontal.rs | 223 ----------- .../polars-core/src/series/arithmetic/mod.rs | 1 - .../polars-ops/src/series/ops/horizontal.rs | 374 +++++++++++++++++- .../src/dsl/function_expr/dispatch.rs | 2 +- crates/polars-python/src/conversion/mod.rs | 1 - 9 files changed, 359 insertions(+), 395 deletions(-) delete mode 100644 crates/polars-core/src/chunked_array/ops/min_max_binary.rs delete mode 100644 crates/polars-core/src/series/arithmetic/horizontal.rs diff --git a/crates/polars-core/src/chunked_array/ops/min_max_binary.rs b/crates/polars-core/src/chunked_array/ops/min_max_binary.rs deleted file mode 100644 index 28e7c491095b..000000000000 --- a/crates/polars-core/src/chunked_array/ops/min_max_binary.rs +++ /dev/null @@ -1,81 +0,0 @@ -use crate::prelude::*; -use crate::series::arithmetic::coerce_lhs_rhs; - -fn min_binary(left: &ChunkedArray, right: &ChunkedArray) -> ChunkedArray -where - T: PolarsNumericType, - T::Native: PartialOrd, -{ - let op = |l: T::Native, r: T::Native| { - if l < r { - l - } else { - r - } - }; - arity::binary_elementwise_values(left, right, op) -} - -fn max_binary(left: &ChunkedArray, right: &ChunkedArray) -> ChunkedArray -where - T: PolarsNumericType, 
- T::Native: PartialOrd, -{ - let op = |l: T::Native, r: T::Native| { - if l > r { - l - } else { - r - } - }; - arity::binary_elementwise_values(left, right, op) -} - -pub(crate) fn min_max_binary_columns( - left: &Column, - right: &Column, - min: bool, -) -> PolarsResult { - if left.dtype().to_physical().is_numeric() - && left.null_count() == 0 - && right.null_count() == 0 - && left.len() == right.len() - { - match (left, right) { - (Column::Series(left), Column::Series(right)) => { - let (lhs, rhs) = coerce_lhs_rhs(left, right)?; - let logical = lhs.dtype(); - let lhs = lhs.to_physical_repr(); - let rhs = rhs.to_physical_repr(); - - with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| { - let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); - let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); - - if min { - min_binary(a, b).into_series().cast(logical) - } else { - max_binary(a, b).into_series().cast(logical) - } - }) - .map(Column::from) - }, - _ => { - let mask = if min { - left.lt(right)? - } else { - left.gt(right)? - }; - - left.zip_with(&mask, right) - }, - } - } else { - let mask = if min { - left.lt(right)? & left.is_not_null() | right.is_null() - } else { - left.gt(right)? 
& left.is_not_null() | right.is_null() - }; - left.zip_with(&mask, right) - } -} diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index c0daaa72bdf6..d598199bfecb 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -28,8 +28,6 @@ pub mod float_sorted_arg_max; mod for_each; pub mod full; pub mod gather; -#[cfg(feature = "zip_with")] -pub(crate) mod min_max_binary; pub(crate) mod nulls; mod reverse; #[cfg(feature = "rolling_window")] diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 09c9b18a03e7..301dd0a53517 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -11,7 +11,6 @@ use crate::chunked_array::metadata::MetadataFlags; #[cfg(feature = "algorithm_group_by")] use crate::chunked_array::ops::unique::is_unique_helper; use crate::prelude::*; -use crate::series::arithmetic::horizontal as series_horizontal; #[cfg(feature = "row_hash")] use crate::utils::split_df; use crate::utils::{slice_offsets, try_get_supertype, Container, NoNull}; @@ -43,12 +42,6 @@ use crate::prelude::sort::{argsort_multiple_row_fmt, prepare_arg_sort}; use crate::series::IsSorted; use crate::POOL; -#[derive(Copy, Clone, Debug)] -pub enum NullStrategy { - Ignore, - Propagate, -} - #[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[strum(serialize_all = "snake_case")] @@ -2791,28 +2784,6 @@ impl DataFrame { Ok(unsafe { DataFrame::new_no_checks(self.height(), col) }) } - /// Aggregate the column horizontally to their min values. - #[cfg(feature = "zip_with")] - pub fn min_horizontal(&self) -> PolarsResult> { - series_horizontal::min_horizontal(&self.columns) - } - - /// Aggregate the column horizontally to their max values. 
- #[cfg(feature = "zip_with")] - pub fn max_horizontal(&self) -> PolarsResult> { - series_horizontal::max_horizontal(&self.columns) - } - - /// Sum all values horizontally across columns. - pub fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { - series_horizontal::sum_horizontal(&self.columns, null_strategy) - } - - /// Compute the mean of all numeric values horizontally across columns. - pub fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { - series_horizontal::mean_horizontal(&self.columns, null_strategy) - } - /// Pipe different functions/ closure operations that work on a DataFrame together. pub fn pipe(self, f: F) -> PolarsResult where @@ -3515,45 +3486,6 @@ mod test { assert_eq!(df.height, 6) } - #[test] - #[cfg(feature = "zip_with")] - #[cfg_attr(miri, ignore)] - fn test_horizontal_agg() { - let a = Column::new("a".into(), [1, 2, 6]); - let b = Column::new("b".into(), [Some(1), None, None]); - let c = Column::new("c".into(), [Some(4), None, Some(3)]); - - let df = DataFrame::new(vec![a, b, c]).unwrap(); - assert_eq!( - Vec::from( - df.mean_horizontal(NullStrategy::Ignore) - .unwrap() - .unwrap() - .f64() - .unwrap() - ), - &[Some(2.0), Some(2.0), Some(4.5)] - ); - assert_eq!( - Vec::from( - df.sum_horizontal(NullStrategy::Ignore) - .unwrap() - .unwrap() - .i32() - .unwrap() - ), - &[Some(6), Some(2), Some(9)] - ); - assert_eq!( - Vec::from(df.min_horizontal().unwrap().unwrap().i32().unwrap()), - &[Some(1), Some(2), Some(3)] - ); - assert_eq!( - Vec::from(df.max_horizontal().unwrap().unwrap().i32().unwrap()), - &[Some(4), Some(2), Some(6)] - ); - } - #[test] fn test_replace_or_add() -> PolarsResult<()> { let mut df = df!( diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index d52916364b0e..de85c793681f 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -339,7 +339,7 
@@ pub mod checked { } } -pub(crate) fn coerce_lhs_rhs<'a>( +pub fn coerce_lhs_rhs<'a>( lhs: &'a Series, rhs: &'a Series, ) -> PolarsResult<(Cow<'a, Series>, Cow<'a, Series>)> { diff --git a/crates/polars-core/src/series/arithmetic/horizontal.rs b/crates/polars-core/src/series/arithmetic/horizontal.rs deleted file mode 100644 index 94996b7302fe..000000000000 --- a/crates/polars-core/src/series/arithmetic/horizontal.rs +++ /dev/null @@ -1,223 +0,0 @@ -use std::borrow::Cow; - -use polars_error::{polars_bail, PolarsResult}; -use polars_utils::pl_str::PlSmallStr; -use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; - -#[cfg(feature = "zip_with")] -use super::min_max_binary::min_max_binary_columns; -use super::{ - ChunkCompareEq, ChunkSet, Column, DataType, FillNullStrategy, IntoColumn, Scalar, Series, - UInt32Chunked, -}; -use crate::chunked_array::cast::CastOptions; -use crate::frame::NullStrategy; -use crate::POOL; - -/// Aggregate the column horizontally to their min values. -/// -/// All columns need to be the same length or a scalar. -#[cfg(feature = "zip_with")] -pub fn min_horizontal(columns: &[Column]) -> PolarsResult> { - let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true); - - match columns.len() { - 0 => Ok(None), - 1 => Ok(Some(columns[0].clone())), - 2 => min_fn(&columns[0], &columns[1]).map(Some), - _ => { - // the try_reduce_with is a bit slower in parallelism, - // but I don't think it matters here as we parallelize over columns, not over elements - POOL.install(|| { - columns - .par_iter() - .map(|s| Ok(Cow::Borrowed(s))) - .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned)) - // we can unwrap the option, because we are certain there is a column - // we started this operation on 3 columns - .unwrap() - .map(|cow| Some(cow.into_owned())) - }) - }, - } -} - -/// Aggregate the column horizontally to their max values. -/// -/// All columns need to be the same length or a scalar. 
-#[cfg(feature = "zip_with")] -pub fn max_horizontal(columns: &[Column]) -> PolarsResult> { - let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false); - - match columns.len() { - 0 => Ok(None), - 1 => Ok(Some(columns[0].clone())), - 2 => max_fn(&columns[0], &columns[1]).map(Some), - _ => { - // the try_reduce_with is a bit slower in parallelism, - // but I don't think it matters here as we parallelize over columns, not over elements - POOL.install(|| { - columns - .par_iter() - .map(|s| Ok(Cow::Borrowed(s))) - .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned)) - // we can unwrap the option, because we are certain there is a column - // we started this operation on 3 columns - .unwrap() - .map(|cow| Some(cow.into_owned())) - }) - }, - } -} - -/// Sum all values horizontally across columns. -/// -/// All columns need to be the same length or a scalar. -pub fn sum_horizontal( - columns: &[Column], - null_strategy: NullStrategy, -) -> PolarsResult> { - let apply_null_strategy = |s: Series, null_strategy: NullStrategy| -> PolarsResult { - if let NullStrategy::Ignore = null_strategy { - // if has nulls - if s.null_count() > 0 { - return s.fill_null(FillNullStrategy::Zero); - } - } - Ok(s) - }; - - let sum_fn = |acc: Series, s: Series, null_strategy: NullStrategy| -> PolarsResult { - let acc: Series = apply_null_strategy(acc, null_strategy)?; - let s = apply_null_strategy(s, null_strategy)?; - // This will do owned arithmetic and can be mutable - std::ops::Add::add(acc, s) - }; - - // @scalar-opt - let non_null_cols = columns - .iter() - .filter(|x| x.dtype() != &DataType::Null) - .map(|c| c.as_materialized_series()) - .collect::>(); - - match non_null_cols.len() { - 0 => { - if columns.is_empty() { - Ok(None) - } else { - // all columns are null dtype, so result is null dtype - Ok(Some(columns[0].clone())) - } - }, - 1 => Ok(Some( - apply_null_strategy( - if non_null_cols[0].dtype() == &DataType::Boolean { - 
non_null_cols[0].cast(&DataType::UInt32)? - } else { - non_null_cols[0].clone() - }, - null_strategy, - )? - .into(), - )), - 2 => sum_fn( - non_null_cols[0].clone(), - non_null_cols[1].clone(), - null_strategy, - ) - .map(Column::from) - .map(Some), - _ => { - // the try_reduce_with is a bit slower in parallelism, - // but I don't think it matters here as we parallelize over columns, not over elements - let out = POOL.install(|| { - non_null_cols - .into_par_iter() - .cloned() - .map(Ok) - .try_reduce_with(|l, r| sum_fn(l, r, null_strategy)) - // We can unwrap because we started with at least 3 columns, so we always get a Some - .unwrap() - }); - out.map(Column::from).map(Some) - }, - } -} - -/// Compute the mean of all values horizontally across columns. -/// -/// All columns need to be the same length or a scalar. -pub fn mean_horizontal( - columns: &[Column], - null_strategy: NullStrategy, -) -> PolarsResult> { - let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| { - let dtype = s.dtype(); - dtype.is_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null() - }); - - if !non_numeric_columns.is_empty() { - let col = non_numeric_columns.first().cloned(); - polars_bail!( - InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})", - col.unwrap().name(), - col.unwrap().dtype(), - ); - } - let columns = numeric_columns.into_iter().cloned().collect::>(); - match columns.len() { - 0 => Ok(None), - 1 => Ok(Some(match columns[0].dtype() { - dt if dt != &DataType::Float32 && !dt.is_decimal() => { - columns[0].cast(&DataType::Float64)? 
- }, - _ => columns[0].clone(), - })), - _ => { - let sum = || sum_horizontal(&columns, null_strategy); - let null_count = || { - columns - .par_iter() - .map(|c| { - c.is_null() - .into_column() - .cast_with_options(&DataType::UInt32, CastOptions::NonStrict) - }) - .reduce_with(|l, r| { - let l = l?; - let r = r?; - let result = std::ops::Add::add(&l, &r)?; - PolarsResult::Ok(result) - }) - // we can unwrap the option, because we are certain there is a column - // we started this operation on 2 columns - .unwrap() - }; - - let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count)); - let sum = sum?; - let null_count = null_count?; - - // value lengths: len - null_count - let value_length: UInt32Chunked = (Column::new_scalar( - PlSmallStr::EMPTY, - Scalar::from(columns.len() as u32), - null_count.len(), - ) - null_count)? - .u32() - .unwrap() - .clone(); - - // make sure that we do not divide by zero - // by replacing with None - let value_length = value_length - .set(&value_length.equal(0), None)? 
- .into_column() - .cast(&DataType::Float64)?; - - sum.map(|sum| std::ops::Div::div(&sum, &value_length)) - .transpose() - }, - } -} diff --git a/crates/polars-core/src/series/arithmetic/mod.rs b/crates/polars-core/src/series/arithmetic/mod.rs index 7cb1fd4674f2..8a4d317276c9 100644 --- a/crates/polars-core/src/series/arithmetic/mod.rs +++ b/crates/polars-core/src/series/arithmetic/mod.rs @@ -1,6 +1,5 @@ mod bitops; mod borrowed; -pub mod horizontal; mod list; mod owned; diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index 7e5cec8c1474..2e96ab27394a 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -1,5 +1,10 @@ -use polars_core::frame::NullStrategy; +use std::borrow::Cow; + +use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; +use polars_core::series::arithmetic::coerce_lhs_rhs; +use polars_core::{with_match_physical_numeric_polars_type, POOL}; +use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> { let mut length = 1; @@ -16,28 +21,320 @@ fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> { Ok(()) } -pub fn max_horizontal(s: &[Column]) -> PolarsResult> { - validate_column_lengths(s)?; - polars_core::series::arithmetic::horizontal::max_horizontal(s) - .map(|opt_c| opt_c.map(|res| res.with_name(s[0].name().clone()))) +pub trait MinMaxHorizontal { + /// Aggregate the column horizontally to their min values. + fn min_horizontal(&self) -> PolarsResult>; + /// Aggregate the column horizontally to their max values. 
+ fn max_horizontal(&self) -> PolarsResult>; +} + +impl MinMaxHorizontal for DataFrame { + fn min_horizontal(&self) -> PolarsResult> { + min_horizontal(self.get_columns()) + } + fn max_horizontal(&self) -> PolarsResult> { + max_horizontal(self.get_columns()) + } +} + +#[derive(Copy, Clone, Debug)] +pub enum NullStrategy { + Ignore, + Propagate, +} + +pub trait SumMeanHorizontal { + /// Sum all values horizontally across columns. + fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult>; + + /// Compute the mean of all numeric values horizontally across columns. + fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult>; +} + +impl SumMeanHorizontal for DataFrame { + fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { + sum_horizontal(self.get_columns(), null_strategy) + } + fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { + mean_horizontal(self.get_columns(), null_strategy) + } +} + +fn min_binary(left: &ChunkedArray, right: &ChunkedArray) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, +{ + let op = |l: T::Native, r: T::Native| { + if l < r { + l + } else { + r + } + }; + arity::binary_elementwise_values(left, right, op) +} + +fn max_binary(left: &ChunkedArray, right: &ChunkedArray) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, +{ + let op = |l: T::Native, r: T::Native| { + if l > r { + l + } else { + r + } + }; + arity::binary_elementwise_values(left, right, op) +} + +fn min_max_binary_columns(left: &Column, right: &Column, min: bool) -> PolarsResult { + if left.dtype().to_physical().is_numeric() + && left.null_count() == 0 + && right.null_count() == 0 + && left.len() == right.len() + { + match (left, right) { + (Column::Series(left), Column::Series(right)) => { + let (lhs, rhs) = coerce_lhs_rhs(left, right)?; + let logical = lhs.dtype(); + let lhs = lhs.to_physical_repr(); + let rhs = rhs.to_physical_repr(); + + 
with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| { + let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + + if min { + min_binary(a, b).into_series().cast(logical) + } else { + max_binary(a, b).into_series().cast(logical) + } + }) + .map(Column::from) + }, + _ => { + let mask = if min { + left.lt(right)? + } else { + left.gt(right)? + }; + + left.zip_with(&mask, right) + }, + } + } else { + let mask = if min { + left.lt(right)? & left.is_not_null() | right.is_null() + } else { + left.gt(right)? & left.is_not_null() | right.is_null() + }; + left.zip_with(&mask, right) + } +} + +pub fn max_horizontal(columns: &[Column]) -> PolarsResult> { + validate_column_lengths(columns)?; + + let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false); + + match columns.len() { + 0 => Ok(None), + 1 => Ok(Some(columns[0].clone())), + 2 => max_fn(&columns[0], &columns[1]).map(Some), + _ => { + // the try_reduce_with is a bit slower in parallelism, + // but I don't think it matters here as we parallelize over columns, not over elements + POOL.install(|| { + columns + .par_iter() + .map(|s| Ok(Cow::Borrowed(s))) + .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned)) + // we can unwrap the option, because we are certain there is a column + // we started this operation on 3 columns + .unwrap() + .map(|cow| Some(cow.into_owned())) + }) + }, + } } -pub fn min_horizontal(s: &[Column]) -> PolarsResult> { - validate_column_lengths(s)?; - polars_core::series::arithmetic::horizontal::min_horizontal(s) - .map(|opt_c| opt_c.map(|res| res.with_name(s[0].name().clone()))) +pub fn min_horizontal(columns: &[Column]) -> PolarsResult> { + validate_column_lengths(columns)?; + + let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true); + + match columns.len() { + 0 => Ok(None), + 1 => Ok(Some(columns[0].clone())), + 2 => min_fn(&columns[0], &columns[1]).map(Some), + _ => { + // the 
try_reduce_with is a bit slower in parallelism, + // but I don't think it matters here as we parallelize over columns, not over elements + POOL.install(|| { + columns + .par_iter() + .map(|s| Ok(Cow::Borrowed(s))) + .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned)) + // we can unwrap the option, because we are certain there is a column + // we started this operation on 3 columns + .unwrap() + .map(|cow| Some(cow.into_owned())) + }) + }, + } } -pub fn sum_horizontal(s: &[Column], null_strategy: NullStrategy) -> PolarsResult> { - validate_column_lengths(s)?; - polars_core::series::arithmetic::horizontal::sum_horizontal(s, null_strategy) - .map(|opt_c| opt_c.map(|res| res.with_name(s[0].name().clone()))) +pub fn sum_horizontal( + columns: &[Column], + null_strategy: NullStrategy, +) -> PolarsResult> { + validate_column_lengths(columns)?; + + let apply_null_strategy = |s: Series, null_strategy: NullStrategy| -> PolarsResult { + if let NullStrategy::Ignore = null_strategy { + // if has nulls + if s.null_count() > 0 { + return s.fill_null(FillNullStrategy::Zero); + } + } + Ok(s) + }; + + let sum_fn = |acc: Series, s: Series, null_strategy: NullStrategy| -> PolarsResult { + let acc: Series = apply_null_strategy(acc, null_strategy)?; + let s = apply_null_strategy(s, null_strategy)?; + // This will do owned arithmetic and can be mutable + std::ops::Add::add(acc, s) + }; + + // @scalar-opt + let non_null_cols = columns + .iter() + .filter(|x| x.dtype() != &DataType::Null) + .map(|c| c.as_materialized_series()) + .collect::>(); + + match non_null_cols.len() { + 0 => { + if columns.is_empty() { + Ok(None) + } else { + // all columns are null dtype, so result is null dtype + Ok(Some(columns[0].clone())) + } + }, + 1 => Ok(Some( + apply_null_strategy( + if non_null_cols[0].dtype() == &DataType::Boolean { + non_null_cols[0].cast(&DataType::UInt32)? + } else { + non_null_cols[0].clone() + }, + null_strategy, + )? 
+ .into(), + )), + 2 => sum_fn( + non_null_cols[0].clone(), + non_null_cols[1].clone(), + null_strategy, + ) + .map(Column::from) + .map(Some), + _ => { + // the try_reduce_with is a bit slower in parallelism, + // but I don't think it matters here as we parallelize over columns, not over elements + let out = POOL.install(|| { + non_null_cols + .into_par_iter() + .cloned() + .map(Ok) + .try_reduce_with(|l, r| sum_fn(l, r, null_strategy)) + // We can unwrap because we started with at least 3 columns, so we always get a Some + .unwrap() + }); + out.map(Column::from).map(Some) + }, + } } -pub fn mean_horizontal(s: &[Column], null_strategy: NullStrategy) -> PolarsResult> { - validate_column_lengths(s)?; - polars_core::series::arithmetic::horizontal::mean_horizontal(s, null_strategy) - .map(|opt_c| opt_c.map(|res| res.with_name(s[0].name().clone()))) +pub fn mean_horizontal( + columns: &[Column], + null_strategy: NullStrategy, +) -> PolarsResult> { + validate_column_lengths(columns)?; + + let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| { + let dtype = s.dtype(); + dtype.is_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null() + }); + + if !non_numeric_columns.is_empty() { + let col = non_numeric_columns.first().cloned(); + polars_bail!( + InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})", + col.unwrap().name(), + col.unwrap().dtype(), + ); + } + let columns = numeric_columns.into_iter().cloned().collect::>(); + match columns.len() { + 0 => Ok(None), + 1 => Ok(Some(match columns[0].dtype() { + dt if dt != &DataType::Float32 && !dt.is_decimal() => { + columns[0].cast(&DataType::Float64)? 
+ }, + _ => columns[0].clone(), + })), + _ => { + let sum = || sum_horizontal(columns.as_slice(), null_strategy); + let null_count = || { + columns + .par_iter() + .map(|c| { + c.is_null() + .into_column() + .cast_with_options(&DataType::UInt32, CastOptions::NonStrict) + }) + .reduce_with(|l, r| { + let l = l?; + let r = r?; + let result = std::ops::Add::add(&l, &r)?; + PolarsResult::Ok(result) + }) + // we can unwrap the option, because we are certain there is a column + // we started this operation on 2 columns + .unwrap() + }; + + let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count)); + let sum = sum?; + let null_count = null_count?; + + // value lengths: len - null_count + let value_length: UInt32Chunked = (Column::new_scalar( + PlSmallStr::EMPTY, + Scalar::from(columns.len() as u32), + null_count.len(), + ) - null_count)? + .u32() + .unwrap() + .clone(); + + // make sure that we do not divide by zero + // by replacing with None + let value_length = value_length + .set(&value_length.equal(0), None)? 
+ .into_column() + .cast(&DataType::Float64)?; + + sum.map(|sum| std::ops::Div::div(&sum, &value_length)) + .transpose() + }, + } } pub fn coalesce_columns(s: &[Column]) -> PolarsResult { @@ -57,3 +354,46 @@ pub fn coalesce_columns(s: &[Column]) -> PolarsResult { } Ok(out) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg_attr(miri, ignore)] + fn test_horizontal_agg() { + let a = Column::new("a".into(), [1, 2, 6]); + let b = Column::new("b".into(), [Some(1), None, None]); + let c = Column::new("c".into(), [Some(4), None, Some(3)]); + + let df = DataFrame::new(vec![a, b, c]).unwrap(); + assert_eq!( + Vec::from( + df.mean_horizontal(NullStrategy::Ignore) + .unwrap() + .unwrap() + .f64() + .unwrap() + ), + &[Some(2.0), Some(2.0), Some(4.5)] + ); + assert_eq!( + Vec::from( + df.sum_horizontal(NullStrategy::Ignore) + .unwrap() + .unwrap() + .i32() + .unwrap() + ), + &[Some(6), Some(2), Some(9)] + ); + assert_eq!( + Vec::from(df.min_horizontal().unwrap().unwrap().i32().unwrap()), + &[Some(1), Some(2), Some(3)] + ); + assert_eq!( + Vec::from(df.max_horizontal().unwrap().unwrap().i32().unwrap()), + &[Some(4), Some(2), Some(6)] + ); + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index 7a338553e1e1..573685d16177 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -1,4 +1,4 @@ -use polars_core::frame::NullStrategy; +use polars_ops::series::NullStrategy; use super::*; diff --git a/crates/polars-python/src/conversion/mod.rs b/crates/polars-python/src/conversion/mod.rs index ac3286aaab36..c99e7fcfecf2 100644 --- a/crates/polars-python/src/conversion/mod.rs +++ b/crates/polars-python/src/conversion/mod.rs @@ -9,7 +9,6 @@ use std::path::PathBuf; #[cfg(feature = "object")] use polars::chunked_array::object::PolarsObjectSafe; use polars::frame::row::Row; -use polars::frame::NullStrategy; #[cfg(feature = "avro")] 
use polars::io::avro::AvroCompression; #[cfg(feature = "cloud")] From ae05c687cb8b6ff14e07f42f0167933686ca9c2a Mon Sep 17 00:00:00 2001 From: Luke Manley Date: Tue, 3 Dec 2024 10:56:56 -0500 Subject: [PATCH 18/20] docs: Fix Rust examples in user guide (#20075) --- Cargo.lock | 105 ++++++++++++++++++ docs/source/src/rust/Cargo.toml | 45 +++++--- .../rust/user-guide/concepts/lazy-vs-eager.rs | 2 +- .../src/rust/user-guide/expressions/window.rs | 50 +++++++-- .../src/rust/user-guide/getting-started.rs | 4 +- docs/source/src/rust/user-guide/io/csv.rs | 20 ++-- docs/source/src/rust/user-guide/io/json.rs | 28 ++--- docs/source/src/rust/user-guide/io/parquet.rs | 16 +-- .../rust/user-guide/transformations/joins.rs | 8 +- 9 files changed, 213 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c2061e769593..6a8d9638a14b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1434,6 +1434,21 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "foreign_vec" version = "0.1.0" @@ -1881,6 +1896,22 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.5.1", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + 
[[package]] name = "hyper-util" version = "0.1.10" @@ -2440,6 +2471,23 @@ dependencies = [ "target-features", ] +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndarray" version = "0.16.1" @@ -2684,12 +2732,50 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "openssl" +version = "0.10.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-sys" +version = "0.9.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "outref" version = "0.5.1" @@ -3939,11 +4025,13 @@ dependencies = [ "http-body-util", "hyper 1.5.1", "hyper-rustls 0.27.3", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", + 
"native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -3957,6 +4045,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-native-tls", "tokio-rustls 0.26.0", "tokio-util", "tower-service", @@ -4786,6 +4875,16 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -4996,6 +5095,12 @@ dependencies = [ "ryu", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" diff --git a/docs/source/src/rust/Cargo.toml b/docs/source/src/rust/Cargo.toml index 8ac668bf82c0..c87e87cffffa 100644 --- a/docs/source/src/rust/Cargo.toml +++ b/docs/source/src/rust/Cargo.toml @@ -14,7 +14,7 @@ aws-sdk-s3 = { version = "1" } aws-smithy-checksums = { version = "0.60.10" } chrono = { workspace = true } rand = { workspace = true } -reqwest = { workspace = true, features = ["blocking"] } +reqwest = { workspace = true, features = ["blocking", "default-tls"] } tokio = { workspace = true } [dependencies.polars] @@ -28,11 +28,12 @@ required-features = ["polars/lazy", "polars/csv"] [[bin]] name = "getting-started" path = "user-guide/getting-started.rs" -required-features = ["polars/lazy", "polars/temporal", "polars/round_series", "polars/strings"] +required-features = ["polars/lazy", "polars/temporal", "polars/round_series", "polars/strings", "polars/is_between"] [[bin]] name = "concepts-data-types-and-structures" path = "user-guide/concepts/data-types-and-structures.rs" +required-features = ["polars/lazy", "polars/temporal"] [[bin]] name = "concepts-expressions" @@ -45,12 +46,12 @@ 
required-features = ["polars/lazy", "polars/csv"] [[bin]] name = "concepts-streaming" path = "user-guide/concepts/streaming.rs" -required-features = ["polars/lazy", "polars/csv"] +required-features = ["polars/lazy", "polars/csv", "polars/streaming"] [[bin]] name = "expressions-aggregation" path = "user-guide/expressions/aggregation.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/csv", "polars/temporal", "polars/dtype-categorical"] [[bin]] name = "expressions-casting" path = "user-guide/expressions/casting.rs" @@ -58,11 +59,11 @@ required-features = ["polars/lazy", "polars/temporal", "polars/strings", "polars [[bin]] name = "expressions-column-selections" path = "user-guide/expressions/column-selections.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/temporal", "polars/regex"] [[bin]] name = "expressions-folds" path = "user-guide/expressions/folds.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/strings", "polars/concat_str", "polars/temporal"] [[bin]] name = "expressions-expression-expansion" path = "user-guide/expressions/expression-expansion.rs" @@ -86,11 +87,11 @@ required-features = ["polars/lazy"] [[bin]] name = "expressions-structs" path = "user-guide/expressions/structs.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/dtype-struct", "polars/rank", "polars/strings", "polars/temporal"] [[bin]] name = "expressions-window" path = "user-guide/expressions/window.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/csv", "polars/rank"] [[bin]] name = "io-cloud-storage" @@ -99,24 +100,32 @@ required-features = ["polars/csv"] [[bin]] name = "io-csv" path = "user-guide/io/csv.rs" -required-features = ["polars/csv"] +required-features = ["polars/lazy", "polars/csv"] [[bin]] name = "io-json" path = "user-guide/io/json.rs" -required-features = ["polars/json"] +required-features = 
["polars/lazy", "polars/json"] [[bin]] name = "io-parquet" path = "user-guide/io/parquet.rs" -required-features = ["polars/parquet"] +required-features = ["polars/lazy", "polars/parquet"] [[bin]] name = "transformations-concatenation" path = "user-guide/transformations/concatenation.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/diagonal_concat"] [[bin]] name = "transformations-joins" path = "user-guide/transformations/joins.rs" -required-features = ["polars/lazy", "polars/strings", "polars/semi_anti_join", "polars/iejoin", "polars/cross_join"] +required-features = [ + "polars/lazy", + "polars/strings", + "polars/semi_anti_join", + "polars/iejoin", + "polars/cross_join", + "polars/temporal", + "polars/asof_join", +] [[bin]] name = "transformations-unpivot" path = "user-guide/transformations/unpivot.rs" @@ -129,20 +138,20 @@ required-features = ["polars/lazy", "polars/pivot"] [[bin]] name = "transformations-time-series-filter" path = "user-guide/transformations/time-series/filter.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/strings", "polars/temporal"] [[bin]] name = "transformations-time-series-parsing" path = "user-guide/transformations/time-series/parsing.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/strings", "polars/temporal", "polars/timezones"] [[bin]] name = "transformations-time-series-resampling" path = "user-guide/transformations/time-series/resampling.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/temporal", "polars/interpolate"] [[bin]] name = "transformations-time-series-rolling" path = "user-guide/transformations/time-series/rolling.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/temporal", "polars/dynamic_group_by", "polars/cum_agg"] [[bin]] name = "transformations-time-series-timezones" path = "user-guide/transformations/time-series/timezones.rs" 
-required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/temporal", "polars/timezones", "polars/strings"] diff --git a/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs b/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs index 955111ac2c11..89f3a6610748 100644 --- a/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs +++ b/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs @@ -35,7 +35,7 @@ fn main() -> Result<(), Box> { .filter(col("sepal_length").gt(lit(5))) .group_by(vec![col("species")]) .agg([col("sepal_width").mean()]); - println!("{:?}", q.explain(true)); + println!("{}", q.explain(true)?); // --8<-- [end:explain] Ok(()) diff --git a/docs/source/src/rust/user-guide/expressions/window.rs b/docs/source/src/rust/user-guide/expressions/window.rs index 891f99793d29..9579b3c0adb7 100644 --- a/docs/source/src/rust/user-guide/expressions/window.rs +++ b/docs/source/src/rust/user-guide/expressions/window.rs @@ -16,11 +16,30 @@ fn main() -> Result<(), Box> { .into_reader_with_file_handle(file) .finish()?; - println!("{}", df); + println!("{}", df.head(Some(5))); // --8<-- [end:pokemon] // --8<-- [start:rank] - // Contribute the Rust translation of the Python example by opening a PR. + let result = df + .clone() + .lazy() + .select([ + col("Name"), + col("Type 1"), + col("Speed") + .rank( + RankOptions { + method: RankMethod::Dense, + descending: true, + }, + None, + ) + .over(["Type 1"]) + .alias("Speed rank"), + ]) + .collect()?; + + println!("{}", result); // --8<-- [end:rank] // --8<-- [start:rank-multiple] @@ -48,7 +67,21 @@ fn main() -> Result<(), Box> { // --8<-- [end:athletes-join] // --8<-- [start:pokemon-mean] - // Contribute the Rust translation of the Python example by opening a PR. 
+ let result = df + .clone() + .lazy() + .select([ + col("Name"), + col("Type 1"), + col("Speed"), + col("Speed") + .mean() + .over(["Type 1"]) + .alias("Mean speed in group"), + ]) + .collect()?; + + println!("{}", result); // --8<-- [end:pokemon-mean] // --8<-- [start:group_by] @@ -102,14 +135,17 @@ fn main() -> Result<(), Box> { .clone() .lazy() .select([ - col("Type 1").head(Some(3)).over(["Type 1"]).flatten(), + col("Type 1") + .head(Some(3)) + .over_with_options(["Type 1"], None, WindowMapping::Explode) + .flatten(), col("Name") .sort_by( ["Speed"], SortMultipleOptions::default().with_order_descending(true), ) .head(Some(3)) - .over(["Type 1"]) + .over_with_options(["Type 1"], None, WindowMapping::Explode) .flatten() .alias("fastest/group"), col("Name") @@ -118,13 +154,13 @@ fn main() -> Result<(), Box> { SortMultipleOptions::default().with_order_descending(true), ) .head(Some(3)) - .over(["Type 1"]) + .over_with_options(["Type 1"], None, WindowMapping::Explode) .flatten() .alias("strongest/group"), col("Name") .sort(Default::default()) .head(Some(3)) - .over(["Type 1"]) + .over_with_options(["Type 1"], None, WindowMapping::Explode) .flatten() .alias("sorted_by_alphabet"), ]) diff --git a/docs/source/src/rust/user-guide/getting-started.rs b/docs/source/src/rust/user-guide/getting-started.rs index 362c99b533c9..550b0b0932e2 100644 --- a/docs/source/src/rust/user-guide/getting-started.rs +++ b/docs/source/src/rust/user-guide/getting-started.rs @@ -21,7 +21,7 @@ fn main() -> Result<(), Box> { // --8<-- [start:csv] use std::fs::File; - let mut file = File::create("../../../assets/data/output.csv").expect("could not create file"); + let mut file = File::create("docs/assets/data/output.csv").expect("could not create file"); CsvWriter::new(&mut file) .include_header(true) .with_separator(b',') @@ -30,7 +30,7 @@ fn main() -> Result<(), Box> { .with_infer_schema_length(None) .with_has_header(true) 
.with_parse_options(CsvParseOptions::default().with_try_parse_dates(true)) - .try_into_reader_with_file_path(Some("../../../assets/data/output.csv".into()))? + .try_into_reader_with_file_path(Some("docs/assets/data/output.csv".into()))? .finish()?; println!("{}", df_csv); // --8<-- [end:csv] diff --git a/docs/source/src/rust/user-guide/io/csv.rs b/docs/source/src/rust/user-guide/io/csv.rs index e9a624895399..1406a9e098cc 100644 --- a/docs/source/src/rust/user-guide/io/csv.rs +++ b/docs/source/src/rust/user-guide/io/csv.rs @@ -4,14 +4,6 @@ fn main() -> Result<(), Box> { // --8<-- [start:read] use polars::prelude::*; - let df = CsvReadOptions::default() - .try_into_reader_with_file_path(Some("docs/assets/data/path.csv".into())) - .unwrap() - .finish() - .unwrap(); - // --8<-- [end:read] - println!("{}", df); - // --8<-- [start:write] let mut df = df!( "foo" => &[1, 2, 3], @@ -23,8 +15,18 @@ fn main() -> Result<(), Box> { CsvWriter::new(&mut file).finish(&mut df).unwrap(); // --8<-- [end:write] + let df = CsvReadOptions::default() + .try_into_reader_with_file_path(Some("docs/assets/data/path.csv".into())) + .unwrap() + .finish() + .unwrap(); + // --8<-- [end:read] + println!("{}", df); + // --8<-- [start:scan] - let lf = LazyCsvReader::new("./test.csv").finish().unwrap(); + let lf = LazyCsvReader::new("docs/assets/data/path.csv") + .finish() + .unwrap(); // --8<-- [end:scan] println!("{}", lf.collect()?); diff --git a/docs/source/src/rust/user-guide/io/json.rs b/docs/source/src/rust/user-guide/io/json.rs index da039cb8cb00..468e20babe8b 100644 --- a/docs/source/src/rust/user-guide/io/json.rs +++ b/docs/source/src/rust/user-guide/io/json.rs @@ -1,20 +1,6 @@ use polars::prelude::*; fn main() -> Result<(), Box> { - // --8<-- [start:read] - use polars::prelude::*; - - let mut file = std::fs::File::open("docs/assets/data/path.json").unwrap(); - let df = JsonReader::new(&mut file).finish().unwrap(); - // --8<-- [end:read] - println!("{}", df); - - // --8<-- [start:readnd] - 
let mut file = std::fs::File::open("docs/assets/data/path.json").unwrap(); - let df = JsonLineReader::new(&mut file).finish().unwrap(); - // --8<-- [end:readnd] - println!("{}", df); - // --8<-- [start:write] let mut df = df!( "foo" => &[1, 2, 3], @@ -37,6 +23,20 @@ fn main() -> Result<(), Box> { .unwrap(); // --8<-- [end:write] + // --8<-- [start:read] + use polars::prelude::*; + + let mut file = std::fs::File::open("docs/assets/data/path.json").unwrap(); + let df = JsonReader::new(&mut file).finish()?; + // --8<-- [end:read] + println!("{}", df); + + // --8<-- [start:readnd] + let mut file = std::fs::File::open("docs/assets/data/path.json").unwrap(); + let df = JsonLineReader::new(&mut file).finish().unwrap(); + // --8<-- [end:readnd] + println!("{}", df); + // --8<-- [start:scan] let lf = LazyJsonLineReader::new("docs/assets/data/path.json") .finish() diff --git a/docs/source/src/rust/user-guide/io/parquet.rs b/docs/source/src/rust/user-guide/io/parquet.rs index fd340fabf222..a554c7051040 100644 --- a/docs/source/src/rust/user-guide/io/parquet.rs +++ b/docs/source/src/rust/user-guide/io/parquet.rs @@ -1,13 +1,6 @@ use polars::prelude::*; fn main() -> Result<(), Box> { - // --8<-- [start:read] - let mut file = std::fs::File::open("docs/assets/data/path.parquet").unwrap(); - - let df = ParquetReader::new(&mut file).finish().unwrap(); - // --8<-- [end:read] - println!("{}", df); - // --8<-- [start:write] let mut df = df!( "foo" => &[1, 2, 3], @@ -19,9 +12,16 @@ fn main() -> Result<(), Box> { ParquetWriter::new(&mut file).finish(&mut df).unwrap(); // --8<-- [end:write] + // --8<-- [start:read] + let mut file = std::fs::File::open("docs/assets/data/path.parquet").unwrap(); + + let df = ParquetReader::new(&mut file).finish().unwrap(); + // --8<-- [end:read] + println!("{}", df); + // --8<-- [start:scan] let args = ScanArgsParquet::default(); - let lf = LazyFrame::scan_parquet("./file.parquet", args).unwrap(); + let lf = 
LazyFrame::scan_parquet("docs/assets/data/path.parquet", args).unwrap(); // --8<-- [end:scan] println!("{}", lf.collect()?); diff --git a/docs/source/src/rust/user-guide/transformations/joins.rs b/docs/source/src/rust/user-guide/transformations/joins.rs index 5d1c50f733b1..c57fa95460a1 100644 --- a/docs/source/src/rust/user-guide/transformations/joins.rs +++ b/docs/source/src/rust/user-guide/transformations/joins.rs @@ -9,9 +9,7 @@ fn main() -> Result<(), Box> { // --8<-- [start:props_groups] let props_groups = CsvReadOptions::default() .with_has_header(true) - .try_into_reader_with_file_path(Some( - "../../../assets/data/monopoly_props_groups.csv".into(), - ))? + .try_into_reader_with_file_path(Some("docs/assets/data/monopoly_props_groups.csv".into()))? .finish()? .head(Some(5)); println!("{}", props_groups); @@ -20,9 +18,7 @@ fn main() -> Result<(), Box> { // --8<-- [start:props_prices] let props_prices = CsvReadOptions::default() .with_has_header(true) - .try_into_reader_with_file_path(Some( - "../../../assets/data/monopoly_props_prices.csv".into(), - ))? + .try_into_reader_with_file_path(Some("docs/assets/data/monopoly_props_prices.csv".into()))? .finish()? 
.head(Some(5)); println!("{}", props_prices); From 875016c28e22b21168069e7a1e7fc11cdca9dc7b Mon Sep 17 00:00:00 2001 From: Henry Harbeck <59268910+henryharbeck@users.noreply.github.com> Date: Wed, 4 Dec 2024 02:42:19 +1000 Subject: [PATCH 19/20] test(python): Fix typo in assertion in datatype copy test (#20121) Co-authored-by: Henry Harbeck --- py-polars/tests/unit/datatypes/test_datatype.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-polars/tests/unit/datatypes/test_datatype.py b/py-polars/tests/unit/datatypes/test_datatype.py index 0fe164ded8c4..804fb52829b6 100644 --- a/py-polars/tests/unit/datatypes/test_datatype.py +++ b/py-polars/tests/unit/datatypes/test_datatype.py @@ -7,5 +7,5 @@ def test_datatype_copy() -> None: dtype = pl.Int64() result = copy.deepcopy(dtype) - assert dtype == dtype + assert dtype == result assert isinstance(result, pl.Int64) From bcfa7ec9d99d373642726ca9d90bffec3fee4d5f Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Wed, 4 Dec 2024 19:27:49 +1100 Subject: [PATCH 20/20] feat: Experimental cloud write support (#20129) --- crates/polars-io/src/cloud/adaptors.rs | 112 +++++-- crates/polars-io/src/cloud/options.rs | 2 +- crates/polars-io/src/path_utils/mod.rs | 4 +- crates/polars-io/src/pl_async.rs | 19 +- crates/polars-io/src/utils/file.rs | 81 +++++ crates/polars-io/src/utils/mod.rs | 1 + crates/polars-lazy/src/frame/mod.rs | 26 +- .../src/executors/sinks/output/csv.rs | 16 +- .../src/executors/sinks/output/ipc.rs | 12 +- .../src/executors/sinks/output/json.rs | 11 +- .../src/executors/sinks/output/parquet.rs | 23 +- crates/polars-pipe/src/pipeline/convert.rs | 10 +- crates/polars-plan/src/plans/options.rs | 1 + crates/polars-python/src/dataframe/io.rs | 97 +++++- crates/polars-python/src/file.rs | 299 ++++++++---------- crates/polars-python/src/lazyframe/general.rs | 94 +++++- .../src/physical_plan/lower_ir.rs | 6 +- crates/polars-utils/src/io.rs | 9 + crates/polars-utils/src/mmap.rs | 11 +- 
py-polars/polars/dataframe/frame.py | 226 +++++++++++-- py-polars/polars/io/cloud/_utils.py | 3 +- .../polars/io/cloud/credential_provider.py | 3 +- py-polars/polars/lazyframe/frame.py | 182 ++++++++++- py-polars/tests/unit/io/test_write.py | 74 +++++ .../tests/unit/streaming/test_streaming_io.py | 6 + 25 files changed, 1035 insertions(+), 293 deletions(-) create mode 100644 crates/polars-io/src/utils/file.rs create mode 100644 py-polars/tests/unit/io/test_write.py diff --git a/crates/polars-io/src/cloud/adaptors.rs b/crates/polars-io/src/cloud/adaptors.rs index 5e034b55a80c..d23d36d23383 100644 --- a/crates/polars-io/src/cloud/adaptors.rs +++ b/crates/polars-io/src/cloud/adaptors.rs @@ -8,8 +8,33 @@ use object_store::ObjectStore; use polars_error::{to_compute_err, PolarsResult}; use tokio::io::AsyncWriteExt; -use super::CloudOptions; -use crate::pl_async::get_runtime; +use super::{object_path_from_str, CloudOptions}; +use crate::pl_async::{get_runtime, get_upload_chunk_size}; + +enum WriterState { + Open(BufWriter), + /// Note: `Err` state is also used as the close state on success. + Err(std::io::Error), +} + +impl WriterState { + fn try_with_writer(&mut self, func: F) -> std::io::Result + where + F: Fn(&mut BufWriter) -> std::io::Result, + { + match self { + Self::Open(writer) => match func(writer) { + Ok(v) => Ok(v), + Err(e) => { + let _ = get_runtime().block_on_potential_spawn(writer.abort()); + *self = Self::Err(e); + self.try_with_writer(func) + }, + }, + Self::Err(e) => Err(std::io::Error::new(e.kind(), e.to_string())), + } + } +} /// Adaptor which wraps the interface of [ObjectStore::BufWriter] exposing a synchronous interface /// which implements `std::io::Write`. 
@@ -20,7 +45,7 @@ use crate::pl_async::get_runtime; /// [ObjectStore::BufWriter]: https://docs.rs/object_store/latest/object_store/buffered/struct.BufWriter.html pub struct CloudWriter { // Internal writer, constructed at creation - writer: BufWriter, + inner: WriterState, } impl CloudWriter { @@ -33,8 +58,10 @@ impl CloudWriter { object_store: Arc, path: Path, ) -> PolarsResult { - let writer = BufWriter::new(object_store, path); - Ok(CloudWriter { writer }) + let writer = BufWriter::with_capacity(object_store, path, get_upload_chunk_size()); + Ok(CloudWriter { + inner: WriterState::Open(writer), + }) } /// Constructs a new CloudWriter from a path and an optional set of CloudOptions. @@ -42,13 +69,36 @@ impl CloudWriter { /// Wrapper around `CloudWriter::new_with_object_store` that is useful if you only have a single write task. /// TODO: Naming? pub async fn new(uri: &str, cloud_options: Option<&CloudOptions>) -> PolarsResult { + if let Some(local_path) = uri.strip_prefix("file://") { + // Local paths must be created first, otherwise object store will not write anything. + if !matches!(std::fs::exists(local_path), Ok(true)) { + panic!( + "[CloudWriter] Expected local file to be created: {}", + local_path + ); + } + } + let (cloud_location, object_store) = crate::cloud::build_object_store(uri, cloud_options, false).await?; - Self::new_with_object_store(object_store, cloud_location.prefix.into()) + Self::new_with_object_store(object_store, object_path_from_str(&cloud_location.prefix)?) 
} - async fn abort(&mut self) -> PolarsResult<()> { - self.writer.abort().await.map_err(to_compute_err) + pub fn close(&mut self) -> PolarsResult<()> { + let WriterState::Open(writer) = &mut self.inner else { + panic!(); + }; + + get_runtime() + .block_on_potential_spawn(async { writer.shutdown().await }) + .map_err(to_compute_err)?; + + self.inner = WriterState::Err(std::io::Error::new( + std::io::ErrorKind::Other, + "impl error: file was closed", + )); + + Ok(()) } } @@ -58,29 +108,27 @@ impl std::io::Write for CloudWriter { // We extend the lifetime for the duration of this function. This is safe as well block the // async runtime here let buf = unsafe { std::mem::transmute::<&[u8], &'static [u8]>(buf) }; - get_runtime().block_on_potential_spawn(async { - let res = self.writer.write_all(buf).await; - if res.is_err() { - let _ = self.abort().await; - } - res.map(|_t| buf.len()) + + self.inner.try_with_writer(|writer| { + get_runtime() + .block_on_potential_spawn(async { writer.write_all(buf).await.map(|_t| buf.len()) }) }) } fn flush(&mut self) -> std::io::Result<()> { - get_runtime().block_on_potential_spawn(async { - let res = self.writer.flush().await; - if res.is_err() { - let _ = self.abort().await; - } - res + self.inner.try_with_writer(|writer| { + get_runtime().block_on_potential_spawn(async { writer.flush().await }) }) } } impl Drop for CloudWriter { fn drop(&mut self) { - let _ = get_runtime().block_on_potential_spawn(self.writer.shutdown()); + // TODO: Properly raise this error instead of panicking. 
+ match self.inner { + WriterState::Open(_) => self.close().unwrap(), + WriterState::Err(_) => {}, + } } } @@ -91,6 +139,8 @@ mod tests { use polars_core::prelude::DataFrame; use super::*; + use crate::prelude::CsvReadOptions; + use crate::SerReader; fn example_dataframe() -> DataFrame { df!( @@ -129,15 +179,27 @@ mod tests { let mut df = example_dataframe(); + let path = "/tmp/cloud_writer_example2.csv"; + + std::fs::File::create(path).unwrap(); + let mut cloud_writer = get_runtime() - .block_on(CloudWriter::new( - "file:///tmp/cloud_writer_example2.csv", - None, - )) + .block_on(CloudWriter::new(format!("file://{}", path).as_str(), None)) .unwrap(); CsvWriter::new(&mut cloud_writer) .finish(&mut df) .expect("Could not write DataFrame as CSV to remote location"); + + cloud_writer.close().unwrap(); + + assert_eq!( + CsvReadOptions::default() + .try_into_reader_with_file_path(Some(path.into())) + .unwrap() + .finish() + .unwrap(), + df + ); } } diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs index 545ab5550890..d45f2a3166a7 100644 --- a/crates/polars-io/src/cloud/options.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -567,7 +567,7 @@ impl CloudOptions { let hf_home = std::env::var("HF_HOME"); let hf_home = hf_home.as_deref(); let hf_home = hf_home.unwrap_or("~/.cache/huggingface"); - let hf_home = resolve_homedir(std::path::Path::new(&hf_home)); + let hf_home = resolve_homedir(&hf_home); let cached_token_path = hf_home.join("token"); let v = std::string::String::from_utf8( diff --git a/crates/polars-io/src/path_utils/mod.rs b/crates/polars-io/src/path_utils/mod.rs index a8a8f1abb3f5..6817de4a71cf 100644 --- a/crates/polars-io/src/path_utils/mod.rs +++ b/crates/polars-io/src/path_utils/mod.rs @@ -35,7 +35,9 @@ pub static POLARS_TEMP_DIR_BASE_PATH: Lazy> = Lazy::new(|| { }); /// Replaces a "~" in the Path with the home directory. 
-pub fn resolve_homedir(path: &Path) -> PathBuf { +pub fn resolve_homedir(path: &dyn AsRef) -> PathBuf { + let path = path.as_ref(); + if path.starts_with("~") { // home crate does not compile on wasm https://github.com/rust-lang/cargo/issues/12297 #[cfg(not(target_family = "wasm"))] diff --git a/crates/polars-io/src/pl_async.rs b/crates/polars-io/src/pl_async.rs index 08540ee09449..5b7af3d1684e 100644 --- a/crates/polars-io/src/pl_async.rs +++ b/crates/polars-io/src/pl_async.rs @@ -14,7 +14,7 @@ pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10; /// Used to determine chunks when splitting large ranges, or combining small /// ranges. -pub(super) static DOWNLOAD_CHUNK_SIZE: Lazy = Lazy::new(|| { +static DOWNLOAD_CHUNK_SIZE: Lazy = Lazy::new(|| { let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE") .as_deref() .map(|x| x.parse().expect("integer")) @@ -31,6 +31,23 @@ pub(super) fn get_download_chunk_size() -> usize { *DOWNLOAD_CHUNK_SIZE } +static UPLOAD_CHUNK_SIZE: Lazy = Lazy::new(|| { + let v: usize = std::env::var("POLARS_UPLOAD_CHUNK_SIZE") + .as_deref() + .map(|x| x.parse().expect("integer")) + .unwrap_or(64 * 1024 * 1024); + + if config::verbose() { + eprintln!("async upload_chunk_size: {}", v) + } + + v +}); + +pub(super) fn get_upload_chunk_size() -> usize { + *UPLOAD_CHUNK_SIZE +} + pub trait GetSize { fn size(&self) -> u64; } diff --git a/crates/polars-io/src/utils/file.rs b/crates/polars-io/src/utils/file.rs new file mode 100644 index 000000000000..465fa735ab18 --- /dev/null +++ b/crates/polars-io/src/utils/file.rs @@ -0,0 +1,81 @@ +use std::io::Write; + +use polars_core::config; +use polars_error::{feature_gated, PolarsError, PolarsResult}; +use polars_utils::mmap::ensure_not_mapped; + +use crate::cloud::CloudOptions; +use crate::{is_cloud_url, resolve_homedir}; + +/// Open a path for writing. Supports cloud paths. 
+pub fn try_get_writeable( + path: &str, + #[cfg_attr(not(feature = "cloud"), allow(unused))] cloud_options: Option<&CloudOptions>, +) -> PolarsResult> { + let is_cloud = is_cloud_url(path); + let verbose = config::verbose(); + + if is_cloud { + feature_gated!("cloud", { + use crate::cloud::CloudWriter; + + if verbose { + eprintln!("try_get_writeable: cloud: {}", path) + } + + if path.starts_with("file://") { + std::fs::File::create(&path[const { "file://".len() }..]) + .map_err(PolarsError::from)?; + } + + let writer = crate::pl_async::get_runtime() + .block_on_potential_spawn(CloudWriter::new(path, cloud_options))?; + Ok(Box::new(writer)) + }) + } else if config::force_async() { + feature_gated!("cloud", { + use crate::cloud::CloudWriter; + + let path = resolve_homedir(&path); + + if verbose { + eprintln!( + "try_get_writeable: forced async: {}", + path.to_str().unwrap() + ) + } + + std::fs::File::create(&path).map_err(PolarsError::from)?; + let path = std::fs::canonicalize(&path)?; + + ensure_not_mapped(&path.metadata()?)?; + + let path = format!( + "file://{}", + if cfg!(target_family = "windows") { + path.to_str().unwrap().strip_prefix(r#"\\?\"#).unwrap() + } else { + path.to_str().unwrap() + } + ); + + if verbose { + eprintln!("try_get_writeable: forced async converted path: {}", path) + } + + let writer = crate::pl_async::get_runtime() + .block_on_potential_spawn(CloudWriter::new(&path, cloud_options))?; + Ok(Box::new(writer)) + }) + } else { + let path = resolve_homedir(&path); + std::fs::File::create(&path).map_err(PolarsError::from)?; + let path = std::fs::canonicalize(&path)?; + + if verbose { + eprintln!("try_get_writeable: local: {}", path.to_str().unwrap()) + } + + Ok(Box::new(polars_utils::open_file_write(&path)?)) + } +} diff --git a/crates/polars-io/src/utils/mod.rs b/crates/polars-io/src/utils/mod.rs index 8cae1ab1ef5f..dbba137930ef 100644 --- a/crates/polars-io/src/utils/mod.rs +++ b/crates/polars-io/src/utils/mod.rs @@ -4,6 +4,7 @@ mod other; 
pub use other::*; #[cfg(feature = "cloud")] pub mod byte_source; +pub mod file; pub mod slice; pub const URL_ENCODE_CHAR_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 617bea135acb..7ee9c4b88aa7 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -757,11 +757,13 @@ impl LazyFrame { self, path: impl AsRef, options: ParquetWriteOptions, + cloud_options: Option, ) -> PolarsResult<()> { self.sink( SinkType::File { path: Arc::new(path.as_ref().to_path_buf()), file_type: FileType::Parquet(options), + cloud_options, }, "collect().write_parquet()", ) @@ -792,11 +794,17 @@ impl LazyFrame { /// into memory. This methods will return an error if the query cannot be completely done in a /// streaming fashion. #[cfg(feature = "ipc")] - pub fn sink_ipc(self, path: impl AsRef, options: IpcWriterOptions) -> PolarsResult<()> { + pub fn sink_ipc( + self, + path: impl AsRef, + options: IpcWriterOptions, + cloud_options: Option, + ) -> PolarsResult<()> { self.sink( SinkType::File { path: Arc::new(path.as_ref().to_path_buf()), file_type: FileType::Ipc(options), + cloud_options, }, "collect().write_ipc()", ) @@ -837,11 +845,17 @@ impl LazyFrame { /// into memory. This methods will return an error if the query cannot be completely done in a /// streaming fashion. #[cfg(feature = "csv")] - pub fn sink_csv(self, path: impl AsRef, options: CsvWriterOptions) -> PolarsResult<()> { + pub fn sink_csv( + self, + path: impl AsRef, + options: CsvWriterOptions, + cloud_options: Option, + ) -> PolarsResult<()> { self.sink( SinkType::File { path: Arc::new(path.as_ref().to_path_buf()), file_type: FileType::Csv(options), + cloud_options, }, "collect().write_csv()", ) @@ -851,11 +865,17 @@ impl LazyFrame { /// into memory. This methods will return an error if the query cannot be completely done in a /// streaming fashion. 
#[cfg(feature = "json")] - pub fn sink_json(self, path: impl AsRef, options: JsonWriterOptions) -> PolarsResult<()> { + pub fn sink_json( + self, + path: impl AsRef, + options: JsonWriterOptions, + cloud_options: Option, + ) -> PolarsResult<()> { self.sink( SinkType::File { path: Arc::new(path.as_ref().to_path_buf()), file_type: FileType::Json(options), + cloud_options, }, "collect().write_ndjson()` or `collect().write_json()", ) diff --git a/crates/polars-pipe/src/executors/sinks/output/csv.rs b/crates/polars-pipe/src/executors/sinks/output/csv.rs index 773287b834b1..68859fa4654a 100644 --- a/crates/polars-pipe/src/executors/sinks/output/csv.rs +++ b/crates/polars-pipe/src/executors/sinks/output/csv.rs @@ -2,7 +2,9 @@ use std::path::Path; use crossbeam_channel::bounded; use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; use polars_io::csv::write::{CsvWriter, CsvWriterOptions}; +use polars_io::utils::file::try_get_writeable; use polars_io::SerWriter; use crate::executors::sinks::output::file_sink::{init_writer_thread, FilesSink, SinkWriter}; @@ -11,9 +13,13 @@ use crate::pipeline::morsels_per_sink; pub struct CsvSink {} impl CsvSink { #[allow(clippy::new_ret_no_self)] - pub fn new(path: &Path, options: CsvWriterOptions, schema: &Schema) -> PolarsResult { - let file = std::fs::File::create(path)?; - let writer = CsvWriter::new(file) + pub fn new( + path: &Path, + options: CsvWriterOptions, + schema: &Schema, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let writer = CsvWriter::new(try_get_writeable(path.to_str().unwrap(), cloud_options)?) 
.include_bom(options.include_bom) .include_header(options.include_header) .with_separator(options.serialize_options.separator) @@ -30,7 +36,7 @@ impl CsvSink { .n_threads(1) .batched(schema)?; - let writer = Box::new(writer) as Box; + let writer = Box::new(writer) as Box; let morsels_per_sink = morsels_per_sink(); let backpressure = morsels_per_sink * 2; @@ -50,7 +56,7 @@ impl CsvSink { } } -impl SinkWriter for polars_io::csv::write::BatchedWriter { +impl SinkWriter for polars_io::csv::write::BatchedWriter { fn _write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { self.write_batch(df) } diff --git a/crates/polars-pipe/src/executors/sinks/output/ipc.rs b/crates/polars-pipe/src/executors/sinks/output/ipc.rs index 06b27ba9811d..858093b878f8 100644 --- a/crates/polars-pipe/src/executors/sinks/output/ipc.rs +++ b/crates/polars-pipe/src/executors/sinks/output/ipc.rs @@ -1,6 +1,8 @@ use std::path::Path; +use cloud::CloudOptions; use crossbeam_channel::bounded; +use file::try_get_writeable; use polars_core::prelude::*; use polars_io::ipc::IpcWriterOptions; use polars_io::prelude::*; @@ -11,9 +13,13 @@ use crate::pipeline::morsels_per_sink; pub struct IpcSink {} impl IpcSink { #[allow(clippy::new_ret_no_self)] - pub fn new(path: &Path, options: IpcWriterOptions, schema: &Schema) -> PolarsResult { - let file = std::fs::File::create(path)?; - let writer = IpcWriter::new(file) + pub fn new( + path: &Path, + options: IpcWriterOptions, + schema: &Schema, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let writer = IpcWriter::new(try_get_writeable(path.to_str().unwrap(), cloud_options)?) 
.with_compression(options.compression) .batched(schema)?; diff --git a/crates/polars-pipe/src/executors/sinks/output/json.rs b/crates/polars-pipe/src/executors/sinks/output/json.rs index 7f76310b4c65..c0cab43e067a 100644 --- a/crates/polars-pipe/src/executors/sinks/output/json.rs +++ b/crates/polars-pipe/src/executors/sinks/output/json.rs @@ -2,12 +2,14 @@ use std::path::Path; use crossbeam_channel::bounded; use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; use polars_io::json::{BatchedWriter, JsonWriterOptions}; +use polars_io::utils::file::try_get_writeable; use crate::executors::sinks::output::file_sink::{init_writer_thread, FilesSink, SinkWriter}; use crate::pipeline::morsels_per_sink; -impl SinkWriter for BatchedWriter { +impl SinkWriter for BatchedWriter { fn _write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { self.write_batch(df) } @@ -24,11 +26,10 @@ impl JsonSink { path: &Path, options: JsonWriterOptions, _schema: &Schema, + cloud_options: Option<&CloudOptions>, ) -> PolarsResult { - let file = std::fs::File::create(path)?; - let writer = BatchedWriter::new(file); - - let writer = Box::new(writer) as Box; + let writer = BatchedWriter::new(try_get_writeable(path.to_str().unwrap(), cloud_options)?); + let writer = Box::new(writer) as Box; let morsels_per_sink = morsels_per_sink(); let backpressure = morsels_per_sink * 2; diff --git a/crates/polars-pipe/src/executors/sinks/output/parquet.rs b/crates/polars-pipe/src/executors/sinks/output/parquet.rs index 2291b1e21fcd..dad4d7c7fc89 100644 --- a/crates/polars-pipe/src/executors/sinks/output/parquet.rs +++ b/crates/polars-pipe/src/executors/sinks/output/parquet.rs @@ -4,9 +4,11 @@ use std::thread::JoinHandle; use crossbeam_channel::{bounded, Receiver, Sender}; use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; use polars_io::parquet::write::{ BatchedWriter, ParquetWriteOptions, ParquetWriter, RowGroupIterColumns, }; +use polars_io::utils::file::try_get_writeable; use 
crate::executors::sinks::output::file_sink::{init_writer_thread, FilesSink, SinkWriter}; use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult}; @@ -14,14 +16,17 @@ use crate::pipeline::morsels_per_sink; type RowGroups = Vec>; -pub(super) fn init_row_group_writer_thread( +pub(super) fn init_row_group_writer_thread( receiver: Receiver>, - writer: Arc>, + writer: Arc>, // this is used to determine when a batch of chunks should be written to disk // all chunks per push should be collected to determine in which order they should // be written morsels_per_sink: usize, -) -> JoinHandle<()> { +) -> JoinHandle<()> +where + W: std::io::Write + Send + 'static, +{ std::thread::spawn(move || { // keep chunks around until all chunks per sink are written // then we write them all at once. @@ -53,15 +58,19 @@ pub(super) fn init_row_group_writer_thread( #[derive(Clone)] pub struct ParquetSink { - writer: Arc>, + writer: Arc>>, io_thread_handle: Arc>>, sender: Sender>, } impl ParquetSink { #[allow(clippy::new_ret_no_self)] - pub fn new(path: &Path, options: ParquetWriteOptions, schema: &Schema) -> PolarsResult { - let file = std::fs::File::create(path)?; - let writer = ParquetWriter::new(file) + pub fn new( + path: &Path, + options: ParquetWriteOptions, + schema: &Schema, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let writer = ParquetWriter::new(try_get_writeable(path.to_str().unwrap(), cloud_options)?) .with_compression(options.compression) .with_data_page_size(options.data_page_size) .with_statistics(options.statistics) diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index b0b19aa26708..6dc62aa4e2cf 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -189,28 +189,28 @@ where }, #[allow(unused_variables)] SinkType::File { - path, file_type, .. 
+ path, file_type, cloud_options } => { let path = path.as_ref().as_path(); match &file_type { #[cfg(feature = "parquet")] FileType::Parquet(options) => { - Box::new(ParquetSink::new(path, *options, input_schema.as_ref())?) + Box::new(ParquetSink::new(path, *options, input_schema.as_ref(), cloud_options.as_ref())?) as Box }, #[cfg(feature = "ipc")] FileType::Ipc(options) => { - Box::new(IpcSink::new(path, *options, input_schema.as_ref())?) + Box::new(IpcSink::new(path, *options, input_schema.as_ref(), cloud_options.as_ref())?) as Box }, #[cfg(feature = "csv")] FileType::Csv(options) => { - Box::new(CsvSink::new(path, options.clone(), input_schema.as_ref())?) + Box::new(CsvSink::new(path, options.clone(), input_schema.as_ref(), cloud_options.as_ref())?) as Box }, #[cfg(feature = "json")] FileType::Json(options) => { - Box::new(JsonSink::new(path, *options, input_schema.as_ref())?) + Box::new(JsonSink::new(path, *options, input_schema.as_ref(), cloud_options.as_ref())?) as Box }, #[allow(unreachable_patterns)] diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index 8d33d031dfa2..2d7479cefc50 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -295,6 +295,7 @@ pub enum SinkType { File { path: Arc, file_type: FileType, + cloud_options: Option, }, #[cfg(feature = "cloud")] Cloud { diff --git a/crates/polars-python/src/dataframe/io.rs b/crates/polars-python/src/dataframe/io.rs index d32a2c11ba8a..bd1015f3ff62 100644 --- a/crates/polars-python/src/dataframe/io.rs +++ b/crates/polars-python/src/dataframe/io.rs @@ -1,14 +1,15 @@ +use std::borrow::Cow; use std::io::BufWriter; use std::num::NonZeroUsize; use std::sync::Arc; +use cloud::credential_provider::PlCredentialProvider; #[cfg(feature = "avro")] use polars::io::avro::AvroCompression; use polars::io::RowIndex; use polars::prelude::*; #[cfg(feature = "parquet")] use polars_parquet::arrow::write::StatisticsOptions; -use 
polars_utils::mmap::ensure_not_mapped; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; @@ -21,7 +22,7 @@ use crate::file::{ get_either_file, get_file_like, get_mmap_bytes_reader, get_mmap_bytes_reader_and_path, read_if_bytesio, EitherRustPythonFile, }; -use crate::prelude::PyCompatLevel; +use crate::prelude::{parse_cloud_options, PyCompatLevel}; #[pymethods] impl PyDataFrame { @@ -154,6 +155,7 @@ impl PyDataFrame { name: name.into(), offset, }); + let result = match get_either_file(py_f, false)? { Py(f) => { let buf = f.as_buffer(); @@ -188,7 +190,7 @@ impl PyDataFrame { #[staticmethod] #[cfg(feature = "json")] - #[pyo3(signature = (py_f, infer_schema_length=None, schema=None, schema_overrides=None))] + #[pyo3(signature = (py_f, infer_schema_length, schema, schema_overrides))] pub fn read_json( py: Python, mut py_f: Bound, @@ -221,7 +223,7 @@ impl PyDataFrame { #[staticmethod] #[cfg(feature = "json")] - #[pyo3(signature = (py_f, ignore_errors, schema=None, schema_overrides=None))] + #[pyo3(signature = (py_f, ignore_errors, schema, schema_overrides))] pub fn read_ndjson( py: Python, mut py_f: Bound, @@ -339,7 +341,11 @@ impl PyDataFrame { } #[cfg(feature = "csv")] - #[pyo3(signature = (py_f, include_bom, include_header, separator, line_terminator, quote_char, batch_size, datetime_format=None, date_format=None, time_format=None, float_scientific=None, float_precision=None, null_value=None, quote_style=None))] + #[pyo3(signature = ( + py_f, include_bom, include_header, separator, line_terminator, quote_char, batch_size, + datetime_format, date_format, time_format, float_scientific, float_precision, null_value, + quote_style, cloud_options, credential_provider, retries + ))] pub fn write_csv( &mut self, py: Python, @@ -357,11 +363,29 @@ impl PyDataFrame { float_precision: Option, null_value: Option, quote_style: Option>, + cloud_options: Option>, + credential_provider: Option, + retries: usize, ) -> PyResult<()> { let null = null_value.unwrap_or_default(); - let 
mut buf = get_file_like(py_f, true)?; + + let cloud_options = if let Ok(path) = py_f.extract::>(py) { + let cloud_options = parse_cloud_options(&path, cloud_options.unwrap_or_default())?; + Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ), + ) + } else { + None + }; + + let f = crate::file::try_get_writeable(py_f, cloud_options.as_ref())?; + py.allow_threads(|| { - CsvWriter::new(&mut buf) + CsvWriter::new(f) .include_bom(include_bom) .include_header(include_header) .with_separator(separator) @@ -382,7 +406,10 @@ impl PyDataFrame { } #[cfg(feature = "parquet")] - #[pyo3(signature = (py_f, compression, compression_level, statistics, row_group_size, data_page_size, partition_by, partition_chunk_size_bytes))] + #[pyo3(signature = ( + py_f, compression, compression_level, statistics, row_group_size, data_page_size, + partition_by, partition_chunk_size_bytes, cloud_options, credential_provider, retries + ))] pub fn write_parquet( &mut self, py: Python, @@ -394,12 +421,16 @@ impl PyDataFrame { data_page_size: Option, partition_by: Option>, partition_chunk_size_bytes: usize, + cloud_options: Option>, + credential_provider: Option, + retries: usize, ) -> PyResult<()> { use polars_io::partition::write_partitioned_dataset; let compression = parse_parquet_compression(compression, compression_level)?; if let Some(partition_by) = partition_by { + // TODO: Support cloud let path = py_f.extract::(py)?; py.allow_threads(|| { @@ -423,9 +454,23 @@ impl PyDataFrame { return Ok(()); }; - let buf = get_file_like(py_f, true)?; + let cloud_options = if let Ok(path) = py_f.extract::>(py) { + let cloud_options = parse_cloud_options(&path, cloud_options.unwrap_or_default())?; + Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ), + ) + } else { + None + }; + + let f = 
crate::file::try_get_writeable(py_f, cloud_options.as_ref())?; + py.allow_threads(|| { - ParquetWriter::new(BufWriter::new(buf)) + ParquetWriter::new(BufWriter::new(f)) .with_compression(compression) .with_statistics(statistics.0) .with_row_group_size(row_group_size) @@ -440,6 +485,8 @@ impl PyDataFrame { pub fn write_json(&mut self, py_f: PyObject) -> PyResult<()> { let file = BufWriter::new(get_file_like(py_f, true)?); + // TODO: Cloud support + JsonWriter::new(file) .with_json_format(JsonFormat::Json) .finish(&mut self.df) @@ -451,6 +498,8 @@ impl PyDataFrame { pub fn write_ndjson(&mut self, py_f: PyObject) -> PyResult<()> { let file = BufWriter::new(get_file_like(py_f, true)?); + // TODO: Cloud support + JsonWriter::new(file) .with_json_format(JsonFormat::JsonLines) .finish(&mut self.df) @@ -460,20 +509,36 @@ impl PyDataFrame { } #[cfg(feature = "ipc")] + #[pyo3(signature = ( + py_f, compression, compat_level, cloud_options, credential_provider, retries + ))] pub fn write_ipc( &mut self, py: Python, py_f: PyObject, compression: Wrap>, compat_level: PyCompatLevel, + cloud_options: Option>, + credential_provider: Option, + retries: usize, ) -> PyResult<()> { - let either = get_either_file(py_f, true)?; - if let EitherRustPythonFile::Rust(ref f) = either { - ensure_not_mapped(f).map_err(PyPolarsErr::from)?; - } - let mut buf = either.into_dyn(); + let cloud_options = if let Ok(path) = py_f.extract::>(py) { + let cloud_options = parse_cloud_options(&path, cloud_options.unwrap_or_default())?; + Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ), + ) + } else { + None + }; + + let f = crate::file::try_get_writeable(py_f, cloud_options.as_ref())?; + py.allow_threads(|| { - IpcWriter::new(&mut buf) + IpcWriter::new(f) .with_compression(compression.0) .with_compat_level(compat_level.0) .finish(&mut self.df) diff --git a/crates/polars-python/src/file.rs 
b/crates/polars-python/src/file.rs index 6bc91de65f21..45a61e777e28 100644 --- a/crates/polars-python/src/file.rs +++ b/crates/polars-python/src/file.rs @@ -10,6 +10,7 @@ use std::path::PathBuf; use polars::io::mmap::MmapBytesReader; use polars_error::polars_err; +use polars_io::cloud::CloudOptions; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyString, PyStringMethods}; @@ -208,6 +209,20 @@ impl EitherRustPythonFile { EitherRustPythonFile::Rust(f) => Box::new(f), } } + + fn into_scan_source_input(self) -> PythonScanSourceInput { + match self { + EitherRustPythonFile::Py(f) => PythonScanSourceInput::Buffer(f.as_bytes()), + EitherRustPythonFile::Rust(f) => PythonScanSourceInput::File(f), + } + } + + pub fn into_dyn_writeable(self) -> Box { + match self { + EitherRustPythonFile::Py(f) => Box::new(f), + EitherRustPythonFile::Rust(f) => Box::new(f), + } + } } pub enum PythonScanSourceInput { @@ -216,6 +231,98 @@ pub enum PythonScanSourceInput { File(File), } +fn try_get_pyfile( + py: Python, + py_f: Bound<'_, PyAny>, + write: bool, +) -> PyResult<(EitherRustPythonFile, Option)> { + let io = py.import_bound("io").unwrap(); + let is_utf8_encoding = |py_f: &Bound| -> PyResult { + let encoding = py_f.getattr("encoding")?; + let encoding = encoding.extract::>()?; + Ok(encoding.eq_ignore_ascii_case("utf-8") || encoding.eq_ignore_ascii_case("utf8")) + }; + + #[cfg(target_family = "unix")] + if let Some(fd) = (py_f.is_exact_instance(&io.getattr("FileIO").unwrap()) + || (py_f.is_exact_instance(&io.getattr("BufferedReader").unwrap()) + || py_f.is_exact_instance(&io.getattr("BufferedWriter").unwrap()) + || py_f.is_exact_instance(&io.getattr("BufferedRandom").unwrap()) + || py_f.is_exact_instance(&io.getattr("BufferedRWPair").unwrap()) + || (py_f.is_exact_instance(&io.getattr("TextIOWrapper").unwrap()) + && is_utf8_encoding(&py_f)?)) + && if write { + // invalidate read buffer + py_f.call_method0("flush").is_ok() + } else { + // flush write 
buffer + py_f.call_method1("seek", (0, 1)).is_ok() + }) + .then(|| { + py_f.getattr("fileno") + .and_then(|fileno| fileno.call0()) + .and_then(|fileno| fileno.extract::()) + .ok() + }) + .flatten() + .map(|fileno| unsafe { + // `File::from_raw_fd()` takes the ownership of the file descriptor. + // When the File is dropped, it closes the file descriptor. + // This is undesired - the Python file object will become invalid. + // Therefore, we duplicate the file descriptor here. + // Closing the duplicated file descriptor will not close + // the original file descriptor; + // and the status, e.g. stream position, is still shared with + // the original file descriptor. + // We use `F_DUPFD_CLOEXEC` here instead of `dup()` + // because it also sets the `O_CLOEXEC` flag on the duplicated file descriptor, + // which `dup()` clears. + // `open()` in both Rust and Python automatically set `O_CLOEXEC` flag; + // it prevents leaking file descriptors across processes, + // and we want to be consistent with them. + // `F_DUPFD_CLOEXEC` is defined in POSIX.1-2008 + // and is present on all alive UNIX(-like) systems. + libc::fcntl(fileno, libc::F_DUPFD_CLOEXEC, 0) + }) + .filter(|fileno| *fileno != -1) + .map(|fileno| fileno as RawFd) + { + return Ok(( + EitherRustPythonFile::Rust(unsafe { File::from_raw_fd(fd) }), + // This works on Linux and BSD with procfs mounted, + // otherwise it fails silently. + fs::canonicalize(format!("/proc/self/fd/{fd}")).ok(), + )); + } + + // Unwrap TextIOWrapper + // Allow subclasses to allow things like pytest.capture.CaptureIO + let py_f = if py_f + .is_instance(&io.getattr("TextIOWrapper").unwrap()) + .unwrap_or_default() + { + if !is_utf8_encoding(&py_f)? { + return Err(PyPolarsErr::from( + polars_err!(InvalidOperation: "file encoding is not UTF-8"), + ) + .into()); + } + // XXX: we have to clear buffer here. + // Is there a better solution? 
+ if write { + py_f.call_method0("flush")?; + } else { + py_f.call_method1("seek", (0, 1))?; + } + py_f.getattr("buffer")? + } else { + py_f + }; + PyFileLikeObject::ensure_requirements(&py_f, !write, write, !write)?; + let f = PyFileLikeObject::new(py_f.to_object(py)); + Ok((EitherRustPythonFile::Py(f), None)) +} + pub fn get_python_scan_source_input( py_f: PyObject, write: bool, @@ -231,93 +338,10 @@ pub fn get_python_scan_source_input( } if let Ok(s) = py_f.extract::>() { - let file_path = std::path::Path::new(&*s); - let file_path = resolve_homedir(file_path); + let file_path = resolve_homedir(&&*s); Ok(PythonScanSourceInput::Path(file_path)) } else { - let io = py.import_bound("io").unwrap(); - let is_utf8_encoding = |py_f: &Bound| -> PyResult { - let encoding = py_f.getattr("encoding")?; - let encoding = encoding.extract::>()?; - Ok(encoding.eq_ignore_ascii_case("utf-8") || encoding.eq_ignore_ascii_case("utf8")) - }; - - #[cfg(target_family = "unix")] - if let Some(fd) = (py_f.is_exact_instance(&io.getattr("FileIO").unwrap()) - || (py_f.is_exact_instance(&io.getattr("BufferedReader").unwrap()) - || py_f.is_exact_instance(&io.getattr("BufferedWriter").unwrap()) - || py_f.is_exact_instance(&io.getattr("BufferedRandom").unwrap()) - || py_f.is_exact_instance(&io.getattr("BufferedRWPair").unwrap()) - || (py_f.is_exact_instance(&io.getattr("TextIOWrapper").unwrap()) - && is_utf8_encoding(&py_f)?)) - && if write { - // invalidate read buffer - py_f.call_method0("flush").is_ok() - } else { - // flush write buffer - py_f.call_method1("seek", (0, 1)).is_ok() - }) - .then(|| { - py_f.getattr("fileno") - .and_then(|fileno| fileno.call0()) - .and_then(|fileno| fileno.extract::()) - .ok() - }) - .flatten() - .map(|fileno| unsafe { - // `File::from_raw_fd()` takes the ownership of the file descriptor. - // When the File is dropped, it closes the file descriptor. - // This is undesired - the Python file object will become invalid. 
- // Therefore, we duplicate the file descriptor here. - // Closing the duplicated file descriptor will not close - // the original file descriptor; - // and the status, e.g. stream position, is still shared with - // the original file descriptor. - // We use `F_DUPFD_CLOEXEC` here instead of `dup()` - // because it also sets the `O_CLOEXEC` flag on the duplicated file descriptor, - // which `dup()` clears. - // `open()` in both Rust and Python automatically set `O_CLOEXEC` flag; - // it prevents leaking file descriptors across processes, - // and we want to be consistent with them. - // `F_DUPFD_CLOEXEC` is defined in POSIX.1-2008 - // and is present on all alive UNIX(-like) systems. - libc::fcntl(fileno, libc::F_DUPFD_CLOEXEC, 0) - }) - .filter(|fileno| *fileno != -1) - .map(|fileno| fileno as RawFd) - { - return Ok(PythonScanSourceInput::File(unsafe { - File::from_raw_fd(fd) - })); - } - - // Unwrap TextIOWrapper - // Allow subclasses to allow things like pytest.capture.CaptureIO - let py_f = if py_f - .is_instance(&io.getattr("TextIOWrapper").unwrap()) - .unwrap_or_default() - { - if !is_utf8_encoding(&py_f)? { - return Err(PyPolarsErr::from( - polars_err!(InvalidOperation: "file encoding is not UTF-8"), - ) - .into()); - } - // XXX: we have to clear buffer here. - // Is there a better solution? - if write { - py_f.call_method0("flush")?; - } else { - py_f.call_method1("seek", (0, 1))?; - } - py_f.getattr("buffer")? 
- } else { - py_f - }; - PyFileLikeObject::ensure_requirements(&py_f, !write, write, !write)?; - Ok(PythonScanSourceInput::Buffer( - PyFileLikeObject::new(py_f.to_object(py)).as_bytes(), - )) + Ok(try_get_pyfile(py, py_f, write)?.0.into_scan_source_input()) } }) } @@ -329,8 +353,7 @@ fn get_either_buffer_or_path( Python::with_gil(|py| { let py_f = py_f.into_bound(py); if let Ok(s) = py_f.extract::>() { - let file_path = std::path::Path::new(&*s); - let file_path = resolve_homedir(file_path); + let file_path = resolve_homedir(&&*s); let f = if write { File::create(&file_path)? } else { @@ -338,90 +361,7 @@ fn get_either_buffer_or_path( }; Ok((EitherRustPythonFile::Rust(f), Some(file_path))) } else { - let io = py.import_bound("io").unwrap(); - let is_utf8_encoding = |py_f: &Bound| -> PyResult { - let encoding = py_f.getattr("encoding")?; - let encoding = encoding.extract::>()?; - Ok(encoding.eq_ignore_ascii_case("utf-8") || encoding.eq_ignore_ascii_case("utf8")) - }; - #[cfg(target_family = "unix")] - if let Some(fd) = (py_f.is_exact_instance(&io.getattr("FileIO").unwrap()) - || (py_f.is_exact_instance(&io.getattr("BufferedReader").unwrap()) - || py_f.is_exact_instance(&io.getattr("BufferedWriter").unwrap()) - || py_f.is_exact_instance(&io.getattr("BufferedRandom").unwrap()) - || py_f.is_exact_instance(&io.getattr("BufferedRWPair").unwrap()) - || (py_f.is_exact_instance(&io.getattr("TextIOWrapper").unwrap()) - && is_utf8_encoding(&py_f)?)) - && if write { - // invalidate read buffer - py_f.call_method0("flush").is_ok() - } else { - // flush write buffer - py_f.call_method1("seek", (0, 1)).is_ok() - }) - .then(|| { - py_f.getattr("fileno") - .and_then(|fileno| fileno.call0()) - .and_then(|fileno| fileno.extract::()) - .ok() - }) - .flatten() - .map(|fileno| unsafe { - // `File::from_raw_fd()` takes the ownership of the file descriptor. - // When the File is dropped, it closes the file descriptor. - // This is undesired - the Python file object will become invalid. 
- // Therefore, we duplicate the file descriptor here. - // Closing the duplicated file descriptor will not close - // the original file descriptor; - // and the status, e.g. stream position, is still shared with - // the original file descriptor. - // We use `F_DUPFD_CLOEXEC` here instead of `dup()` - // because it also sets the `O_CLOEXEC` flag on the duplicated file descriptor, - // which `dup()` clears. - // `open()` in both Rust and Python automatically set `O_CLOEXEC` flag; - // it prevents leaking file descriptors across processes, - // and we want to be consistent with them. - // `F_DUPFD_CLOEXEC` is defined in POSIX.1-2008 - // and is present on all alive UNIX(-like) systems. - libc::fcntl(fileno, libc::F_DUPFD_CLOEXEC, 0) - }) - .filter(|fileno| *fileno != -1) - .map(|fileno| fileno as RawFd) - { - return Ok(( - EitherRustPythonFile::Rust(unsafe { File::from_raw_fd(fd) }), - // This works on Linux and BSD with procfs mounted, - // otherwise it fails silently. - fs::canonicalize(format!("/proc/self/fd/{fd}")).ok(), - )); - } - - // Unwrap TextIOWrapper - // Allow subclasses to allow things like pytest.capture.CaptureIO - let py_f = if py_f - .is_instance(&io.getattr("TextIOWrapper").unwrap()) - .unwrap_or_default() - { - if !is_utf8_encoding(&py_f)? { - return Err(PyPolarsErr::from( - polars_err!(InvalidOperation: "file encoding is not UTF-8"), - ) - .into()); - } - // XXX: we have to clear buffer here. - // Is there a better solution? - if write { - py_f.call_method0("flush")?; - } else { - py_f.call_method1("seek", (0, 1))?; - } - py_f.getattr("buffer")? 
- } else { - py_f - }; - PyFileLikeObject::ensure_requirements(&py_f, !write, write, !write)?; - let f = PyFileLikeObject::new(py_f.to_object(py)); - Ok((EitherRustPythonFile::Py(f), None)) + try_get_pyfile(py, py_f, write) } }) } @@ -473,3 +413,20 @@ pub fn get_mmap_bytes_reader_and_path<'a>( } } } + +pub fn try_get_writeable( + py_f: PyObject, + cloud_options: Option<&CloudOptions>, +) -> PyResult> { + Python::with_gil(|py| { + let py_f = py_f.into_bound(py); + + if let Ok(s) = py_f.extract::>() { + polars::prelude::file::try_get_writeable(&s, cloud_options) + .map_err(PyPolarsErr::from) + .map_err(|e| e.into()) + } else { + Ok(try_get_pyfile(py, py_f, true)?.0.into_dyn_writeable()) + } + }) +} diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 6a47a0e191a8..993b7a232b1b 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -671,7 +671,10 @@ impl PyLazyFrame { } #[cfg(all(feature = "streaming", feature = "parquet"))] - #[pyo3(signature = (path, compression, compression_level, statistics, row_group_size, data_page_size, maintain_order))] + #[pyo3(signature = ( + path, compression, compression_level, statistics, row_group_size, data_page_size, + maintain_order, cloud_options, credential_provider, retries + ))] fn sink_parquet( &self, py: Python, @@ -682,6 +685,9 @@ impl PyLazyFrame { row_group_size: Option, data_page_size: Option, maintain_order: bool, + cloud_options: Option>, + credential_provider: Option, + retries: usize, ) -> PyResult<()> { let compression = parse_parquet_compression(compression, compression_level)?; @@ -693,40 +699,73 @@ impl PyLazyFrame { maintain_order, }; + let cloud_options = { + let cloud_options = + parse_cloud_options(path.to_str().unwrap(), cloud_options.unwrap_or_default())?; + Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + 
credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_func_object), + ), + ) + }; + // if we don't allow threads and we have udfs trying to acquire the gil from different // threads we deadlock. py.allow_threads(|| { let ldf = self.ldf.clone(); - ldf.sink_parquet(path, options).map_err(PyPolarsErr::from) + ldf.sink_parquet(path, options, cloud_options) + .map_err(PyPolarsErr::from) })?; Ok(()) } #[cfg(all(feature = "streaming", feature = "ipc"))] - #[pyo3(signature = (path, compression, maintain_order))] + #[pyo3(signature = (path, compression, maintain_order, cloud_options, credential_provider, retries))] fn sink_ipc( &self, py: Python, path: PathBuf, compression: Option>, maintain_order: bool, + cloud_options: Option>, + credential_provider: Option, + retries: usize, ) -> PyResult<()> { let options = IpcWriterOptions { compression: compression.map(|c| c.0), maintain_order, }; + let cloud_options = { + let cloud_options = + parse_cloud_options(path.to_str().unwrap(), cloud_options.unwrap_or_default())?; + Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_func_object), + ), + ) + }; + // if we don't allow threads and we have udfs trying to acquire the gil from different // threads we deadlock. 
py.allow_threads(|| { let ldf = self.ldf.clone(); - ldf.sink_ipc(path, options).map_err(PyPolarsErr::from) + ldf.sink_ipc(path, options, cloud_options) + .map_err(PyPolarsErr::from) })?; Ok(()) } #[cfg(all(feature = "streaming", feature = "csv"))] - #[pyo3(signature = (path, include_bom, include_header, separator, line_terminator, quote_char, batch_size, datetime_format, date_format, time_format, float_scientific, float_precision, null_value, quote_style, maintain_order))] + #[pyo3(signature = ( + path, include_bom, include_header, separator, line_terminator, quote_char, batch_size, + datetime_format, date_format, time_format, float_scientific, float_precision, null_value, + quote_style, maintain_order, cloud_options, credential_provider, retries + ))] fn sink_csv( &self, py: Python, @@ -745,6 +784,9 @@ impl PyLazyFrame { null_value: Option, quote_style: Option>, maintain_order: bool, + cloud_options: Option>, + credential_provider: Option, + retries: usize, ) -> PyResult<()> { let quote_style = quote_style.map_or(QuoteStyle::default(), |wrap| wrap.0); let null_value = null_value.unwrap_or(SerializeOptions::default().null); @@ -770,26 +812,60 @@ impl PyLazyFrame { serialize_options, }; + let cloud_options = { + let cloud_options = + parse_cloud_options(path.to_str().unwrap(), cloud_options.unwrap_or_default())?; + Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_func_object), + ), + ) + }; + // if we don't allow threads and we have udfs trying to acquire the gil from different // threads we deadlock. 
py.allow_threads(|| { let ldf = self.ldf.clone(); - ldf.sink_csv(path, options).map_err(PyPolarsErr::from) + ldf.sink_csv(path, options, cloud_options) + .map_err(PyPolarsErr::from) })?; Ok(()) } #[allow(clippy::too_many_arguments)] #[cfg(all(feature = "streaming", feature = "json"))] - #[pyo3(signature = (path, maintain_order))] - fn sink_json(&self, py: Python, path: PathBuf, maintain_order: bool) -> PyResult<()> { + #[pyo3(signature = (path, maintain_order, cloud_options, credential_provider, retries))] + fn sink_json( + &self, + py: Python, + path: PathBuf, + maintain_order: bool, + cloud_options: Option>, + credential_provider: Option, + retries: usize, + ) -> PyResult<()> { let options = JsonWriterOptions { maintain_order }; + let cloud_options = { + let cloud_options = + parse_cloud_options(path.to_str().unwrap(), cloud_options.unwrap_or_default())?; + Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_func_object), + ), + ) + }; + // if we don't allow threads and we have udfs trying to acquire the gil from different // threads we deadlock. 
py.allow_threads(|| { let ldf = self.ldf.clone(); - ldf.sink_json(path, options).map_err(PyPolarsErr::from) + ldf.sink_json(path, options, cloud_options) + .map_err(PyPolarsErr::from) })?; Ok(()) } diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index fcbec84a2e53..c9304ab8b635 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -210,7 +210,11 @@ pub fn lower_ir( let phys_input = lower_ir!(*input)?; PhysNodeKind::InMemorySink { input: phys_input } }, - SinkType::File { path, file_type } => { + SinkType::File { + path, + file_type, + cloud_options: _, + } => { let path = path.clone(); let file_type = file_type.clone(); diff --git a/crates/polars-utils/src/io.rs b/crates/polars-utils/src/io.rs index d472c4a8186d..cce2eafe22a7 100644 --- a/crates/polars-utils/src/io.rs +++ b/crates/polars-utils/src/io.rs @@ -23,6 +23,15 @@ pub fn open_file(path: &Path) -> PolarsResult { File::open(path).map_err(|err| _limit_path_len_io_err(path, err)) } +pub fn open_file_write(path: &Path) -> PolarsResult { + std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .map_err(|err| _limit_path_len_io_err(path, err)) +} + pub fn create_file(path: &Path) -> PolarsResult { File::create(path).map_err(|err| _limit_path_len_io_err(path, err)) } diff --git a/crates/polars-utils/src/mmap.rs b/crates/polars-utils/src/mmap.rs index ef07714d591f..52cd7f04b0a7 100644 --- a/crates/polars-utils/src/mmap.rs +++ b/crates/polars-utils/src/mmap.rs @@ -277,6 +277,8 @@ impl MMapSemaphore { #[cfg(target_family = "unix")] { + // FIXME: We aren't handling the case where the file is already open in write-mode here. 
+ use std::os::unix::fs::MetadataExt; let metadata = file.metadata()?; @@ -324,13 +326,16 @@ impl Drop for MMapSemaphore { } } -pub fn ensure_not_mapped(#[allow(unused)] file: &File) -> PolarsResult<()> { +pub fn ensure_not_mapped( + #[cfg_attr(not(target_family = "unix"), allow(unused))] file_md: &std::fs::Metadata, +) -> PolarsResult<()> { + // TODO: We need to actually register that this file has been write-opened and prevent + // read-opening this file based on that. #[cfg(target_family = "unix")] { use std::os::unix::fs::MetadataExt; let guard = MEMORY_MAPPED_FILES.lock().unwrap(); - let metadata = file.metadata()?; - if guard.contains_key(&(metadata.dev(), metadata.ino())) { + if guard.contains_key(&(file_md.dev(), file_md.ino())) { polars_bail!(ComputeError: "cannot write to file: already memory mapped"); } } diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 49e1053e375d..fd79a0c88ccb 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -177,6 +177,7 @@ ) from polars._utils.various import NoDefault from polars.interchange.dataframe import PolarsDataFrame + from polars.io.cloud import CredentialProviderFunction from polars.ml.torch import PolarsDataset if sys.version_info >= (3, 10): @@ -2758,6 +2759,9 @@ def write_csv( float_precision: int | None = ..., null_value: str | None = ..., quote_style: CsvQuoteStyle | None = ..., + storage_options: dict[str, Any] | None = ..., + credential_provider: CredentialProviderFunction | Literal["auto"] | None = ..., + retries: int = ..., ) -> str: ... @overload @@ -2778,6 +2782,9 @@ def write_csv( float_precision: int | None = ..., null_value: str | None = ..., quote_style: CsvQuoteStyle | None = ..., + storage_options: dict[str, Any] | None = ..., + credential_provider: CredentialProviderFunction | Literal["auto"] | None = ..., + retries: int = ..., ) -> None: ... 
def write_csv( @@ -2797,6 +2804,11 @@ def write_csv( float_precision: int | None = None, null_value: str | None = None, quote_style: CsvQuoteStyle | None = None, + storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction + | Literal["auto"] + | None = "auto", + retries: int = 2, ) -> str | None: """ Write to comma-separated values (CSV) file. @@ -2856,6 +2868,30 @@ def write_csv( Namely, when writing a field that does not parse as a valid float or integer, then quotes will be used even if they aren`t strictly necessary. + storage_options + Options that indicate how to connect to a cloud provider. + + The cloud providers currently supported are AWS, GCP, and Azure. + See supported keys here: + + * `aws `_ + * `gcp `_ + * `azure `_ + * Hugging Face (`hf://`): Accepts an API key under the `token` parameter: \ + `{'token': '...'}`, or by setting the `HF_TOKEN` environment variable. + + If `storage_options` is not provided, Polars will try to infer the + information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + retries + Number of retries if accessing a cloud instance fails. 
Examples -------- @@ -2910,6 +2946,18 @@ def write_csv_to_string() -> str: elif isinstance(file, (str, os.PathLike)): file = normalize_filepath(file) + from polars.io.cloud.credential_provider import _maybe_init_credential_provider + + credential_provider = _maybe_init_credential_provider( + credential_provider, file, storage_options, "write_csv" + ) + + if storage_options: + storage_options = list(storage_options.items()) # type: ignore[assignment] + else: + # Handle empty dict input + storage_options = None + self._df.write_csv( file, include_bom, @@ -2925,6 +2973,9 @@ def write_csv_to_string() -> str: float_precision, null_value, quote_style, + cloud_options=storage_options, + credential_provider=credential_provider, + retries=retries, ) if should_return_buffer: @@ -3540,6 +3591,11 @@ def write_ipc( *, compression: IpcCompression = "uncompressed", compat_level: CompatLevel | None = None, + storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction + | Literal["auto"] + | None = "auto", + retries: int = 2, ) -> BytesIO: ... @overload @@ -3549,6 +3605,11 @@ def write_ipc( *, compression: IpcCompression = "uncompressed", compat_level: CompatLevel | None = None, + storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction + | Literal["auto"] + | None = "auto", + retries: int = 2, ) -> None: ... @deprecate_renamed_parameter("future", "compat_level", version="1.1") @@ -3558,6 +3619,11 @@ def write_ipc( *, compression: IpcCompression = "uncompressed", compat_level: CompatLevel | None = None, + storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction + | Literal["auto"] + | None = "auto", + retries: int = 2, ) -> BytesIO | None: """ Write to Arrow IPC binary stream or Feather file. @@ -3574,6 +3640,30 @@ def write_ipc( compat_level Use a specific compatibility level when exporting Polars' internal data structures. 
+ storage_options + Options that indicate how to connect to a cloud provider. + + The cloud providers currently supported are AWS, GCP, and Azure. + See supported keys here: + + * `aws `_ + * `gcp `_ + * `azure `_ + * Hugging Face (`hf://`): Accepts an API key under the `token` parameter: \ + `{'token': '...'}`, or by setting the `HF_TOKEN` environment variable. + + If `storage_options` is not provided, Polars will try to infer the + information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + retries + Number of retries if accessing a cloud instance fails. Examples -------- @@ -3603,7 +3693,30 @@ def write_ipc( if compression is None: compression = "uncompressed" - self._df.write_ipc(file, compression, compat_level) + from polars.io.cloud.credential_provider import _maybe_init_credential_provider + + credential_provider = ( + None + if return_bytes + else _maybe_init_credential_provider( + credential_provider, file, storage_options, "write_ipc" + ) + ) + + if storage_options: + storage_options = list(storage_options.items()) # type: ignore[assignment] + else: + # Handle empty dict input + storage_options = None + + self._df.write_ipc( + file, + compression, + compat_level, + cloud_options=storage_options, + credential_provider=credential_provider, + retries=retries, + ) return file if return_bytes else None # type: ignore[return-value] @overload @@ -3692,6 +3805,11 @@ def write_parquet( pyarrow_options: dict[str, Any] | None = None, partition_by: str | Sequence[str] | None = None, partition_chunk_size_bytes: int = 4_294_967_296, + storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction 
+ | Literal["auto"] + | None = "auto", + retries: int = 2, ) -> None: """ Write to Apache Parquet file. @@ -3752,6 +3870,30 @@ def write_parquet( writing. Note this is calculated using the size of the DataFrame in memory - the size of the output file may differ depending on the file format / compression. + storage_options + Options that indicate how to connect to a cloud provider. + + The cloud providers currently supported are AWS, GCP, and Azure. + See supported keys here: + + * `aws `_ + * `gcp `_ + * `azure `_ + * Hugging Face (`hf://`): Accepts an API key under the `token` parameter: \ + `{'token': '...'}`, or by setting the `HF_TOKEN` environment variable. + + If `storage_options` is not provided, Polars will try to infer the + information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + retries + Number of retries if accessing a cloud instance fails. 
Examples -------- @@ -3833,41 +3975,57 @@ def write_parquet( **(pyarrow_options or {}), ) + return + + from polars.io.cloud.credential_provider import _maybe_init_credential_provider + + credential_provider = _maybe_init_credential_provider( + credential_provider, file, storage_options, "write_parquet" + ) + + if storage_options: + storage_options = list(storage_options.items()) # type: ignore[assignment] else: - if isinstance(statistics, bool) and statistics: - statistics = { - "min": True, - "max": True, - "distinct_count": False, - "null_count": True, - } - elif isinstance(statistics, bool) and not statistics: - statistics = {} - elif statistics == "full": - statistics = { - "min": True, - "max": True, - "distinct_count": True, - "null_count": True, - } + # Handle empty dict input + storage_options = None + + if isinstance(statistics, bool) and statistics: + statistics = { + "min": True, + "max": True, + "distinct_count": False, + "null_count": True, + } + elif isinstance(statistics, bool) and not statistics: + statistics = {} + elif statistics == "full": + statistics = { + "min": True, + "max": True, + "distinct_count": True, + "null_count": True, + } - if partition_by is not None: - msg = "The `partition_by` parameter of `write_parquet` is considered unstable." - issue_unstable_warning(msg) - - if isinstance(partition_by, str): - partition_by = [partition_by] - - self._df.write_parquet( - file, - compression, - compression_level, - statistics, - row_group_size, - data_page_size, - partition_by=partition_by, - partition_chunk_size_bytes=partition_chunk_size_bytes, - ) + if partition_by is not None: + msg = "The `partition_by` parameter of `write_parquet` is considered unstable." 
+ issue_unstable_warning(msg) + + if isinstance(partition_by, str): + partition_by = [partition_by] + + self._df.write_parquet( + file, + compression, + compression_level, + statistics, + row_group_size, + data_page_size, + partition_by=partition_by, + partition_chunk_size_bytes=partition_chunk_size_bytes, + cloud_options=storage_options, + credential_provider=credential_provider, + retries=retries, + ) def write_database( self, diff --git a/py-polars/polars/io/cloud/_utils.py b/py-polars/polars/io/cloud/_utils.py index 7279838aa005..91ad65a4ca94 100644 --- a/py-polars/polars/io/cloud/_utils.py +++ b/py-polars/polars/io/cloud/_utils.py @@ -16,7 +16,8 @@ def _first_scan_path( | list[Path] | list[IO[str]] | list[IO[bytes]] - | list[bytes], + | list[bytes] + | None, ) -> str | Path | None: if isinstance(source, (str, Path)): return source diff --git a/py-polars/polars/io/cloud/credential_provider.py b/py-polars/polars/io/cloud/credential_provider.py index 69f2bdbdf67a..dd98683f316e 100644 --- a/py-polars/polars/io/cloud/credential_provider.py +++ b/py-polars/polars/io/cloud/credential_provider.py @@ -229,7 +229,8 @@ def _maybe_init_credential_provider( | list[Path] | list[IO[str]] | list[IO[bytes]] - | list[bytes], + | list[bytes] + | None, storage_options: dict[str, Any] | None, caller_name: str, ) -> CredentialProviderFunction | CredentialProvider | None: diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 7419308fdd6f..77f50efd9d85 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -120,6 +120,7 @@ UniqueKeepStrategy, ) from polars.dependencies import numpy as np + from polars.io.cloud import CredentialProviderFunction if sys.version_info >= (3, 10): from typing import Concatenate, ParamSpec @@ -2259,6 +2260,11 @@ def sink_parquet( slice_pushdown: bool = True, collapse_joins: bool = True, no_optimization: bool = False, + storage_options: dict[str, Any] | None = None, + 
credential_provider: CredentialProviderFunction + | Literal["auto"] + | None = "auto", + retries: int = 2, ) -> None: """ Evaluate the query in streaming mode and write to a Parquet file. @@ -2326,6 +2332,30 @@ def sink_parquet( Collapse a join and filters into a faster join no_optimization Turn off (certain) optimizations. + storage_options + Options that indicate how to connect to a cloud provider. + + The cloud providers currently supported are AWS, GCP, and Azure. + See supported keys here: + + * `aws `_ + * `gcp `_ + * `azure `_ + * Hugging Face (`hf://`): Accepts an API key under the `token` parameter: \ + `{'token': '...'}`, or by setting the `HF_TOKEN` environment variable. + + If `storage_options` is not provided, Polars will try to infer the + information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + retries + Number of retries if accessing a cloud instance fails. 
Returns ------- @@ -2363,6 +2393,18 @@ def sink_parquet( "null_count": True, } + from polars.io.cloud.credential_provider import _maybe_init_credential_provider + + credential_provider = _maybe_init_credential_provider( + credential_provider, path, storage_options, "sink_parquet" + ) + + if storage_options: + storage_options = list(storage_options.items()) # type: ignore[assignment] + else: + # Handle empty dict input + storage_options = None + return lf.sink_parquet( path=normalize_filepath(path), compression=compression, @@ -2371,6 +2413,9 @@ def sink_parquet( row_group_size=row_group_size, data_page_size=data_page_size, maintain_order=maintain_order, + cloud_options=storage_options, + credential_provider=credential_provider, + retries=retries, ) @unstable() @@ -2387,6 +2432,11 @@ def sink_ipc( slice_pushdown: bool = True, collapse_joins: bool = True, no_optimization: bool = False, + storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction + | Literal["auto"] + | None = "auto", + retries: int = 2, ) -> None: """ Evaluate the query in streaming mode and write to an IPC file. @@ -2421,6 +2471,30 @@ def sink_ipc( Collapse a join and filters into a faster join no_optimization Turn off (certain) optimizations. + storage_options + Options that indicate how to connect to a cloud provider. + + The cloud providers currently supported are AWS, GCP, and Azure. + See supported keys here: + + * `aws <https://docs.rs/object_store/latest/object_store/aws/enum.AmazonS3ConfigKey.html>`_ + * `gcp <https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html>`_ + * `azure <https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html>`_ + * Hugging Face (`hf://`): Accepts an API key under the `token` parameter: \ + `{'token': '...'}`, or by setting the `HF_TOKEN` environment variable. + + If `storage_options` is not provided, Polars will try to infer the + information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. 
warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + retries + Number of retries if accessing a cloud instance fails. Returns ------- @@ -2441,10 +2515,25 @@ def sink_ipc( no_optimization=no_optimization, ) + from polars.io.cloud.credential_provider import _maybe_init_credential_provider + + credential_provider = _maybe_init_credential_provider( + credential_provider, path, storage_options, "sink_ipc" + ) + + if storage_options: + storage_options = list(storage_options.items()) # type: ignore[assignment] + else: + # Handle empty dict input + storage_options = None + return lf.sink_ipc( path=path, compression=compression, maintain_order=maintain_order, + cloud_options=storage_options, + credential_provider=credential_provider, + retries=retries, ) @unstable() @@ -2473,6 +2562,11 @@ def sink_csv( slice_pushdown: bool = True, collapse_joins: bool = True, no_optimization: bool = False, + storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction + | Literal["auto"] + | None = "auto", + retries: int = 2, ) -> None: """ Evaluate the query in streaming mode and write to a CSV file. @@ -2555,6 +2649,30 @@ def sink_csv( Collapse a join and filters into a faster join no_optimization Turn off (certain) optimizations. + storage_options + Options that indicate how to connect to a cloud provider. + + The cloud providers currently supported are AWS, GCP, and Azure. + See supported keys here: + + * `aws <https://docs.rs/object_store/latest/object_store/aws/enum.AmazonS3ConfigKey.html>`_ + * `gcp <https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html>`_ + * `azure <https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html>`_ + * Hugging Face (`hf://`): Accepts an API key under the `token` parameter: \ + `{'token': '...'}`, or by setting the `HF_TOKEN` environment variable. + + If `storage_options` is not provided, Polars will try to infer the + information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. 
The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + retries + Number of retries if accessing a cloud instance fails. Returns ------- @@ -2582,6 +2700,18 @@ def sink_csv( no_optimization=no_optimization, ) + from polars.io.cloud.credential_provider import _maybe_init_credential_provider + + credential_provider = _maybe_init_credential_provider( + credential_provider, path, storage_options, "sink_csv" + ) + + if storage_options: + storage_options = list(storage_options.items()) # type: ignore[assignment] + else: + # Handle empty dict input + storage_options = None + return lf.sink_csv( path=normalize_filepath(path), include_bom=include_bom, @@ -2598,6 +2728,9 @@ def sink_csv( null_value=null_value, quote_style=quote_style, maintain_order=maintain_order, + cloud_options=storage_options, + credential_provider=credential_provider, + retries=retries, ) @unstable() @@ -2613,6 +2746,11 @@ def sink_ndjson( slice_pushdown: bool = True, collapse_joins: bool = True, no_optimization: bool = False, + storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction + | Literal["auto"] + | None = "auto", + retries: int = 2, ) -> None: """ Evaluate the query in streaming mode and write to an NDJSON file. @@ -2644,6 +2782,30 @@ def sink_ndjson( Collapse a join and filters into a faster join no_optimization Turn off (certain) optimizations. + storage_options + Options that indicate how to connect to a cloud provider. + + The cloud providers currently supported are AWS, GCP, and Azure. + See supported keys here: + + * `aws <https://docs.rs/object_store/latest/object_store/aws/enum.AmazonS3ConfigKey.html>`_ + * `gcp <https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html>`_ + * `azure <https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html>`_ + * Hugging Face (`hf://`): Accepts an API key under the `token` parameter: \ + `{'token': '...'}`, or by setting the `HF_TOKEN` environment variable. 
+ + If `storage_options` is not provided, Polars will try to infer the + information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + retries + Number of retries if accessing a cloud instance fails. Returns ------- @@ -2664,7 +2826,25 @@ def sink_ndjson( no_optimization=no_optimization, ) - return lf.sink_json(path=path, maintain_order=maintain_order) + from polars.io.cloud.credential_provider import _maybe_init_credential_provider + + credential_provider = _maybe_init_credential_provider( + credential_provider, path, storage_options, "sink_ndjson" + ) + + if storage_options: + storage_options = list(storage_options.items()) # type: ignore[assignment] + else: + # Handle empty dict input + storage_options = None + + return lf.sink_json( + path=path, + maintain_order=maintain_order, + cloud_options=storage_options, + credential_provider=credential_provider, + retries=retries, + ) def _set_sink_optimizations( self, diff --git a/py-polars/tests/unit/io/test_write.py b/py-polars/tests/unit/io/test_write.py new file mode 100644 index 000000000000..e49577e08291 --- /dev/null +++ b/py-polars/tests/unit/io/test_write.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Callable + +import pytest + +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + +READ_WRITE_FUNC_PARAM = [ + (pl.read_parquet, pl.DataFrame.write_parquet), + (lambda *a: pl.scan_csv(*a).collect(), pl.DataFrame.write_csv), + (lambda *a: pl.scan_ipc(*a).collect(), pl.DataFrame.write_ipc), + # Sink + (pl.read_parquet, lambda df, path: pl.DataFrame.lazy(df).sink_parquet(path)), + ( + lambda 
*a: pl.scan_csv(*a).collect(), + lambda df, path: pl.DataFrame.lazy(df).sink_csv(path), + ), + ( + lambda *a: pl.scan_ipc(*a).collect(), + lambda df, path: pl.DataFrame.lazy(df).sink_ipc(path), + ), + ( + lambda *a: pl.scan_ndjson(*a).collect(), + lambda df, path: pl.DataFrame.lazy(df).sink_ndjson(path), + ), +] + + +@pytest.mark.parametrize( + ("read_func", "write_func"), + READ_WRITE_FUNC_PARAM, +) +@pytest.mark.write_disk +def test_write_async( + read_func: Callable[[Path], pl.DataFrame], + write_func: Callable[[pl.DataFrame, Path], None], + tmp_path: Path, +) -> None: + tmp_path.mkdir(exist_ok=True) + path = (tmp_path / "1").absolute() + path = f"file://{path}" # type: ignore[assignment] + + df = pl.DataFrame({"x": 1}) + + write_func(df, path) + + assert_frame_equal(read_func(path), df) + + +@pytest.mark.parametrize( + ("read_func", "write_func"), + READ_WRITE_FUNC_PARAM, +) +@pytest.mark.parametrize("opt_absolute_fn", [Path, Path.absolute]) +@pytest.mark.write_disk +def test_write_async_force_async( + read_func: Callable[[Path], pl.DataFrame], + write_func: Callable[[pl.DataFrame, Path], None], + opt_absolute_fn: Callable[[Path], Path], + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") + tmp_path.mkdir(exist_ok=True) + path = opt_absolute_fn(tmp_path / "1") + + df = pl.DataFrame({"x": 1}) + + write_func(df, path) + + assert_frame_equal(read_func(path), df) diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py index c304686f89d3..402fa5b6b86e 100644 --- a/py-polars/tests/unit/streaming/test_streaming_io.py +++ b/py-polars/tests/unit/streaming/test_streaming_io.py @@ -154,6 +154,9 @@ def test_sink_csv_with_options() -> None: null_value="BOOM", quote_style="always", maintain_order=False, + storage_options=None, + credential_provider="auto", + retries=2, ) ldf.optimization_toggle().sink_csv.assert_called_with( @@ -172,6 +175,9 @@ def 
test_sink_csv_with_options() -> None: null_value="BOOM", quote_style="always", maintain_order=False, + cloud_options=None, + credential_provider=None, + retries=2, )