From f3f2aaa9ad8efabdb2af73e7c0d4b9c67cffcc70 Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Mon, 9 Dec 2024 01:21:58 -0800
Subject: [PATCH] feat(list): add fixed-size list support for value_counts

Closes #3519
---
 .../src/array/fixed_size_list_array.rs      | 12 ++++++-
 src/daft-functions/src/list/value_counts.rs | 14 +++++---
 tests/expressions/test_expressions.py       | 35 +++++++++++++++++++
 3 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs
index 29aeb2179a..8253107f07 100644
--- a/src/daft-core/src/array/fixed_size_list_array.rs
+++ b/src/daft-core/src/array/fixed_size_list_array.rs
@@ -198,8 +198,18 @@ impl FixedSizeListArray {
     }
 
     pub fn to_list(&self) -> ListArray {
+        let field = &self.field;
+
+        let DataType::FixedSizeList(inner_type, _) = &field.dtype else {
+            unreachable!("Expected FixedSizeListArray, got {:?}", field.dtype);
+        };
+
+        let datatype = DataType::List(inner_type.clone());
+        let mut field = (**field).clone();
+        field.dtype = datatype;
+
         ListArray::new(
-            self.field.clone(),
+            field,
             self.flat_child.clone(),
             self.generate_offsets(),
             self.validity.clone(),
diff --git a/src/daft-functions/src/list/value_counts.rs b/src/daft-functions/src/list/value_counts.rs
index 7f1d882749..c39baa08f8 100644
--- a/src/daft-functions/src/list/value_counts.rs
+++ b/src/daft-functions/src/list/value_counts.rs
@@ -29,11 +29,15 @@ impl ScalarUDF for ListValueCountsFunction {
 
         let data_field = data.to_field(schema)?;
 
-        let DataType::List(inner_type) = &data_field.dtype else {
-            return Err(DaftError::TypeError(format!(
-                "Expected list, got {}",
-                data_field.dtype
-            )));
+        let inner_type = match &data_field.dtype {
+            DataType::List(inner_type) => inner_type,
+            DataType::FixedSizeList(inner_type, _) => inner_type,
+            _ => {
+                return Err(DaftError::TypeError(format!(
+                    "Expected list or fixed size list, got {}",
+                    data_field.dtype
+                )));
+            }
         };
 
         let map_type = DataType::Map {
diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py
index 51fe60583d..723747f9be 100644
--- a/tests/expressions/test_expressions.py
+++ b/tests/expressions/test_expressions.py
@@ -602,6 +602,41 @@ def test_list_value_counts_nested():
     )
 
 
+def test_list_value_counts_fixed_size():
+    # Create data with lists of fixed size
+    data = {
+        "fixed_list": [
+            [1, 2, 3],
+            [4, 3, 4],
+            [4, 5, 6],
+            [1, 2, 3],
+            [7, 8, 9],
+            None,
+        ]
+    }
+
+    # Create DataFrame and cast the column to fixed size list
+    df = daft.from_pydict(data).with_column(
+        "fixed_list", daft.col("fixed_list").cast(DataType.fixed_size_list(DataType.int64(), 3))
+    )
+
+    df = df.with_column("fixed_list", col("fixed_list").cast(DataType.fixed_size_list(DataType.int64(), 3)))
+
+    # Get value counts
+    result = df.with_column("value_counts", col("fixed_list").list.value_counts())
+
+    # Verify the value counts
+    result_dict = result.to_pydict()
+    assert result_dict["value_counts"] == [
+        [(1, 1), (2, 1), (3, 1)],
+        [(4, 2), (3, 1)],
+        [(4, 1), (5, 1), (6, 1)],
+        [(1, 1), (2, 1), (3, 1)],
+        [(7, 1), (8, 1), (9, 1)],
+        [],
+    ]
+
+
 def test_list_value_counts_degenerate():
     import pyarrow as pa
 